# Copyright 2021 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for lsh_tree."""

import typing
from unittest import mock

from absl.testing import absltest
import numpy as np
from scipy import stats

from clustering import clustering_params
from clustering import lsh
from clustering import lsh_tree
from clustering import test_utils


def get_test_origin_points(nonprivate_count=5, dim=10):
  """Returns an all-zero array with one row per point and `dim` columns.

  Args:
    nonprivate_count: number of points (rows) to create.
    dim: dimension (columns) of each point.
  """
  shape = (nonprivate_count, dim)
  origin_points = np.zeros(shape)
  return origin_points


def get_test_sim_hash(dim=10, max_hash_len=1):
  """Builds an lsh.SimHash, defaulting the arguments a test does not care about.

  Args:
    dim: forwarded to lsh.SimHash as `dim`.
    max_hash_len: forwarded to lsh.SimHash as `max_hash_len`.
  """
  sim_hash_kwargs = {'dim': dim, 'max_hash_len': max_hash_len}
  return lsh.SimHash(**sim_hash_kwargs)


class TestLshTreeNode(lsh_tree.LshTreeNode):
  """Fake tree node that replaces the real child-hashing logic for tests.

  children() deterministically sends the first `frac_zero` fraction of the
  points to the child whose hash_prefix gains a '0' and the remaining points to
  the child whose hash_prefix gains a '1'. Private counts are always the real
  number of points plus one. Two test nodes compare equal when their
  hash_prefix and private_count match.
  """
  # Fraction of the nonprivate_points that go to the '0' child.
  frac_zero: float

  def __init__(self, *args, frac_zero: float = 0.5, **kwargs):
    """Initializes test node.

    Args:
      *args: positional arguments forwarded to the parent class
      frac_zero: fraction of the nonprivate_points that go to the '0' child
      **kwargs: keyword arguments forwarded to the parent class
    """
    super().__init__(*args, **kwargs)
    self.frac_zero = frac_zero

  def get_private_count(self) -> int:
    """Stores and returns a fake private count: number of points plus one."""
    fake_count = len(self.nonprivate_points) + 1
    self.private_count = fake_count
    return fake_count

  def children(self) -> typing.List[lsh_tree.LshTreeNode]:
    """Returns the two fake children ('0' child first, then '1' child)."""
    points = self.nonprivate_points
    # Index of the first point that belongs to the '1' child.
    split_at = int(len(points) * self.frac_zero)
    zero_child = TestLshTreeNode(
        self.hash_prefix + '0',
        points[:split_at],
        self.coreset_param,
        self.sim_hash,
        private_count=split_at + 1,
        frac_zero=self.frac_zero)
    one_child = TestLshTreeNode(
        self.hash_prefix + '1',
        points[split_at:],
        self.coreset_param,
        self.sim_hash,
        private_count=len(points) - split_at + 1,
        frac_zero=self.frac_zero)
    return [zero_child, one_child]

  def __eq__(self, other):
    """Compares only hash_prefix and private_count, which is enough for tests."""
    if not isinstance(other, TestLshTreeNode):
      return False
    if self.hash_prefix != other.hash_prefix:
      return False
    return self.private_count == other.private_count


class LshTreeTest(absltest.TestCase):

  @mock.patch.object(stats.dlaplace, 'rvs', return_value=-5, autospec=True)
  def test_get_private_count_basic(self, mock_dlaplace_fn):
    """Checks the private count and the sampler call with dlaplace.rvs mocked."""
    sim_hash = get_test_sim_hash()
    node = lsh_tree.LshTreeNode(
        hash_prefix='',
        nonprivate_points=get_test_origin_points(nonprivate_count=30),
        coreset_param=test_utils.get_test_coreset_param(
            epsilon=5, frac_sum=0.2, frac_group_count=0.8, max_depth=9),
        sim_hash=sim_hash)

    # 30 real points combined with the mocked draw of -5 is expected to be 25.
    private_count = node.get_private_count()
    self.assertEqual(private_count, 25)
    # The sampler must have been used exactly once, with parameter 0.4.
    mock_dlaplace_fn.assert_called_once_with(0.4)

  def test_get_private_count_cache(self):
    """Two get_private_count calls on the same node return the same value."""
    sim_hash = get_test_sim_hash()
    node = lsh_tree.LshTreeNode(
        hash_prefix='',
        nonprivate_points=get_test_origin_points(nonprivate_count=30),
        coreset_param=test_utils.get_test_coreset_param(epsilon=0.01),
        sim_hash=sim_hash)

    first_call = node.get_private_count()
    second_call = node.get_private_count()
    self.assertEqual(first_call, second_call)

  def test_get_children(self):
    """children() yields a '00' and a '01' child holding the expected points."""
    hash_prefix, dim, max_hash_len = '0', 5, 2
    datapoints = np.array([[1.5, 0, 1, 0.5, 0], [1, 1, 0, 0.7, 0.1],
                           [0.8, 0.1, 1, 0.2, 0.4], [0.1, 0.5, 0.3, 0.7, 0.8],
                           [-0.5, 0.1, -0.3, -0.4, 0.2]])
    # Returns children regardless of whether the node should branch. The
    # filtering in the algorithm is done after.
    coreset_param = test_utils.get_test_coreset_param(max_depth=max_hash_len)
    projection_vectors = np.array([[0, 1, 1, -1, 0], [1, 0, -1, 0, 0]])
    sh = lsh.SimHash(dim, max_hash_len, projection_vectors)
    node = lsh_tree.LshTreeNode(hash_prefix, datapoints, coreset_param, sh)
    children = node.children()

    self.assertSameElements([child.hash_prefix for child in children],
                            ['00', '01'])
    # Points each child is expected to hold, keyed by the child's hash prefix.
    expected_points = {
        '00': datapoints[[0, 1]],
        '01': datapoints[[2, 3, 4]],
    }
    for child in children:
      self.assertEqual(child.coreset_param, coreset_param)
      self.assertEqual(child.sim_hash, sh)
      # assert_array_equal fails on a shape mismatch and prints the differing
      # elements, unlike assertTrue((a == b).all()).
      np.testing.assert_array_equal(child.nonprivate_points,
                                    expected_points[child.hash_prefix])

  def test_get_children_one_empty(self):
    """children() still returns both children when one of them gets no points."""
    hash_prefix, dim, max_hash_len = '0', 5, 2
    datapoints = np.array([[1.5, 0, 1, 0.5, 0], [1, 1, 0, 0.7, 0.1]])
    # Returns children regardless of whether the node should branch. The
    # filtering in the algorithm is done after.
    coreset_param = test_utils.get_test_coreset_param(max_depth=max_hash_len)
    projection_vectors = np.array([[0, 1, 1, -1, 0], [1, 0, -1, 0, 0]])
    sh = lsh.SimHash(dim, max_hash_len, projection_vectors)
    node = lsh_tree.LshTreeNode(hash_prefix, datapoints, coreset_param, sh)
    children = node.children()

    self.assertSameElements([child.hash_prefix for child in children],
                            ['00', '01'])
    for child in children:
      self.assertEqual(child.coreset_param, coreset_param)
      self.assertEqual(child.sim_hash, sh)
      if child.hash_prefix == '00':
        # assert_array_equal fails on a shape mismatch and prints the
        # differing elements, unlike assertTrue((a == b).all()).
        np.testing.assert_array_equal(child.nonprivate_points,
                                      datapoints[[0, 1]])
      if child.hash_prefix == '01':
        self.assertEmpty(child.nonprivate_points)

  def test_get_children_error(self):
    """children() raises ValueError for a '00' node when max_hash_len is 2."""
    # Presumably the error is raised because the prefix already uses all
    # max_hash_len bits, so there is no further bit to branch on.
    hash_prefix = '00'
    dim = 5
    max_hash_len = 2
    datapoints = np.array([[1.5, 0, 1, 0.5, 0], [1, 1, 0, 0.7, 0.1]])
    coreset_param = test_utils.get_test_coreset_param(max_depth=max_hash_len)
    projection_vectors = np.array([[0, 1, 1, -1, 0], [1, 0, -1, 0, 0]])
    sh = lsh.SimHash(dim, max_hash_len, projection_vectors)
    node = lsh_tree.LshTreeNode(hash_prefix, datapoints, coreset_param, sh)

    with self.assertRaises(ValueError):
      node.children()

  def test_filter_branching_nodes_too_few_points(self):
    """A node whose private count is below the threshold is filtered out."""
    sim_hash = get_test_sim_hash()
    # The node holds 15 real points (above the threshold of 10) but a private
    # count of only 1: the private count is what the filter should look at.
    node = lsh_tree.LshTreeNode(
        '0',
        get_test_origin_points(nonprivate_count=15),
        test_utils.get_test_coreset_param(min_num_points_in_branching_node=10),
        sim_hash,
        private_count=1)
    level: lsh_tree.LshTreeLevel = [node]
    branching_nodes = lsh_tree.LshTree.filter_branching_nodes(level)
    self.assertEmpty(branching_nodes)

  def test_filter_branching_nodes_enough_points(self):
    """A node whose private count reaches the threshold is kept."""
    sim_hash = get_test_sim_hash()
    # Private count of 20 against a branching threshold of 10.
    node = lsh_tree.LshTreeNode(
        '0',
        get_test_origin_points(nonprivate_count=15),
        test_utils.get_test_coreset_param(min_num_points_in_branching_node=10),
        sim_hash,
        private_count=20)
    level: lsh_tree.LshTreeLevel = [node]
    branching_nodes = lsh_tree.LshTree.filter_branching_nodes(level)
    self.assertSequenceEqual(branching_nodes, level)

  def test_get_next_level_empty_list(self):
    """get_next_level on an empty list of nodes yields an empty result."""
    next_level = lsh_tree.LshTree.get_next_level([])
    self.assertEmpty(next_level)

  def test_get_next_level(self):
    """A branching test node with 16 points yields two children of 8 points."""
    sim_hash = get_test_sim_hash()
    coreset_param = test_utils.get_test_coreset_param(
        min_num_points_in_branching_node=10, min_num_points_in_node=5)

    def make_node(prefix, num_points, private_count):
      """Builds a test node sharing this test's coreset_param and sim_hash."""
      return TestLshTreeNode(
          prefix,
          get_test_origin_points(nonprivate_count=num_points),
          coreset_param,
          sim_hash,
          private_count=private_count)

    level: lsh_tree.LshTreeLevel = [make_node('0', 16, 20)]
    # With the default frac_zero of 0.5 each child holds 8 points, so the fake
    # private count of each child is 9.
    expected_next_level = [make_node('00', 8, 9), make_node('01', 8, 9)]

    branching_nodes = lsh_tree.LshTree.filter_branching_nodes(level)
    next_level = lsh_tree.LshTree.get_next_level(branching_nodes)
    self.assertSequenceEqual(next_level, expected_next_level)

  def test_get_next_level_filters_children_node(self):
    """Children whose private count is below min_num_points_in_node are cut."""
    sim_hash = get_test_sim_hash()
    coreset_param = test_utils.get_test_coreset_param(
        min_num_points_in_branching_node=10, min_num_points_in_node=9)

    def make_node(prefix, num_points, private_count, **extra):
      """Builds a test node sharing this test's coreset_param and sim_hash."""
      return TestLshTreeNode(
          prefix,
          get_test_origin_points(nonprivate_count=num_points),
          coreset_param,
          sim_hash,
          private_count=private_count,
          **extra)

    level: lsh_tree.LshTreeLevel = [
        # Splits 5/5: both children have a private count of 6, which is below
        # min_num_points_in_node, so neither should survive.
        make_node('0', 10, 11),
        # Splits 2/8: the children have private counts of 3 and 9, so only the
        # child with 9 should be in the result.
        make_node('1', 10, 11, frac_zero=0.2),
    ]
    expected_next_level = [make_node('11', 8, 9)]

    branching_nodes = lsh_tree.LshTree.filter_branching_nodes(level)
    next_level = lsh_tree.LshTree.get_next_level(branching_nodes)
    self.assertSequenceEqual(next_level, expected_next_level)

  def test_root_node(self):
    """root_node builds a root with an empty prefix and a private count set."""
    raw_points = [[1, 2, 1], [0.4, 0.2, 0.8], [3, 0, 3]]
    data = clustering_params.Data(raw_points, radius=4.3)
    coreset_param = test_utils.get_test_coreset_param(radius=4.3, max_depth=20)

    root = lsh_tree.root_node(data, coreset_param)

    self.assertEqual(root.hash_prefix, '')
    self.assertSequenceEqual(root.nonprivate_points, raw_points)
    self.assertEqual(root.coreset_param, coreset_param)
    # The SimHash is expected to take its dimension from the points and its
    # maximum hash length from max_depth.
    root_sim_hash = root.sim_hash
    self.assertEqual(root_sim_hash.dim, 3)
    self.assertEqual(root_sim_hash.max_hash_len, 20)
    # No private count was supplied, so root_node must have filled one in.
    self.assertIsNotNone(root.private_count)

  def test_root_node_provide_private_count(self):
    """root_node keeps a private count that the caller supplies."""
    raw_points = [[1, 2, 1], [0.4, 0.2, 0.8], [3, 0, 3]]
    data = clustering_params.Data(raw_points, radius=4.3)
    coreset_param = test_utils.get_test_coreset_param(radius=4.3, max_depth=20)

    root = lsh_tree.root_node(data, coreset_param, private_count=10)

    self.assertEqual(root.hash_prefix, '')
    self.assertSequenceEqual(root.nonprivate_points, raw_points)
    self.assertEqual(root.coreset_param, coreset_param)
    root_sim_hash = root.sim_hash
    self.assertEqual(root_sim_hash.dim, 3)
    self.assertEqual(root_sim_hash.max_hash_len, 20)
    # The supplied value must be stored unchanged.
    self.assertEqual(root.private_count, 10)

  def test_lsh_tree_empty_root_errors(self):
    """Building an LshTree from a root with private count 0 raises ValueError."""
    points = get_test_origin_points(nonprivate_count=15)
    coreset_param = test_utils.get_test_coreset_param()
    sim_hash = get_test_sim_hash()
    # 15 real points, but a private count of zero.
    test_root = lsh_tree.LshTreeNode(
        '0', points, coreset_param, sim_hash, private_count=0)
    with self.assertRaises(ValueError):
      lsh_tree.LshTree(test_root)

  def test_lsh_tree_negative_count_root_errors(self):
    """Building an LshTree from a root with a negative count raises ValueError."""
    points = get_test_origin_points(nonprivate_count=15)
    coreset_param = test_utils.get_test_coreset_param()
    sim_hash = get_test_sim_hash()
    # 15 real points, but a negative private count.
    test_root = lsh_tree.LshTreeNode(
        '0', points, coreset_param, sim_hash, private_count=-10)
    with self.assertRaises(ValueError):
      lsh_tree.LshTree(test_root)

  def test_lsh_tree(self):
    """LshTree.tree holds the expected nodes per level for an uneven split."""
    # Expected tree. Each node is written as nonprivate count + 1, a left
    # branch appends 0 to the prefix and a right branch appends 1. Nodes in
    # parentheses are filtered out.
    #           64+1
    #          /    \
    #      8+1      56+1
    #     /  \      /   \
    # (1+1)   7+1  7+1   49+1
    #                   /   \
    #                (6+1)  43+1
    root_point_count = 64
    sh = get_test_sim_hash()
    cp = test_utils.get_test_coreset_param(
        min_num_points_in_node=8,
        min_num_points_in_branching_node=9,
        max_depth=3)
    root_points = get_test_origin_points(root_point_count)
    test_root = TestLshTreeNode('', root_points, cp, sh, frac_zero=0.125)

    def expected_node(prefix, num_points):
      """Builds an expected node sharing this test's cp and sh."""
      return TestLshTreeNode(
          prefix, get_test_origin_points(num_points), cp, sh)

    expected_tree = {
        0: [expected_node('', 64)],
        1: [expected_node('0', 8), expected_node('1', 56)],
        2: [
            expected_node('01', 7),
            expected_node('10', 7),
            expected_node('11', 49),
        ],
        3: [expected_node('111', 43)],
    }
    tree = lsh_tree.LshTree(test_root)
    self.assertEqual(tree.tree, expected_tree)

  def test_lsh_tree_branching_node_becomes_leaf(self):
    """LshTree.tree stops at level 1 when every grandchild is filtered out."""
    # Expected tree. Each node is written as nonprivate count + 1, a left
    # branch appends 0 to the prefix and a right branch appends 1. Nodes in
    # parentheses are filtered out.
    #             64+1
    #          /       \
    #      32+1          32+1
    #     /   \         /   \
    # (16+1) (16+1)  (16+1) (16+1)
    root_point_count = 64
    sh = get_test_sim_hash()
    cp = test_utils.get_test_coreset_param(
        min_num_points_in_node=20,
        min_num_points_in_branching_node=30,
        max_depth=5)
    root_points = get_test_origin_points(root_point_count)
    test_root = TestLshTreeNode('', root_points, cp, sh, frac_zero=0.5)

    def expected_node(prefix, num_points):
      """Builds an expected node sharing this test's cp and sh."""
      return TestLshTreeNode(
          prefix, get_test_origin_points(num_points), cp, sh)

    expected_tree = {
        0: [expected_node('', 64)],
        1: [expected_node('0', 32), expected_node('1', 32)],
    }
    tree = lsh_tree.LshTree(test_root)
    self.assertEqual(tree.tree, expected_tree)

  def test_lsh_tree_leaves(self):
    """LshTree.leaves lists the unbranched surviving nodes of an uneven tree."""
    # Expected tree. Each node is written as nonprivate count + 1, a left
    # branch appends 0 to the prefix and a right branch appends 1. Nodes in
    # parentheses are filtered out.
    #           64+1
    #          /    \
    #      8+1      56+1
    #     /  \      /   \
    # (1+1)   7+1  7+1   49+1
    #                   /   \
    #                (6+1)  43+1
    root_point_count = 64
    sh = get_test_sim_hash()
    cp = test_utils.get_test_coreset_param(
        min_num_points_in_node=8,
        min_num_points_in_branching_node=9,
        max_depth=3)
    root_points = get_test_origin_points(root_point_count)
    test_root = TestLshTreeNode('', root_points, cp, sh, frac_zero=0.125)

    def expected_node(prefix, num_points):
      """Builds an expected node sharing this test's cp and sh."""
      return TestLshTreeNode(
          prefix, get_test_origin_points(num_points), cp, sh)

    expected_leaves = [
        expected_node('01', 7),
        expected_node('10', 7),
        expected_node('111', 43),
    ]
    tree = lsh_tree.LshTree(test_root)
    self.assertEqual(tree.leaves, expected_leaves)

  def test_lsh_tree_leaves_branching_node_becomes_leaf(self):
    """Nodes whose children are all filtered out are reported as leaves."""
    # Expected tree. Each node is written as nonprivate count + 1, a left
    # branch appends 0 to the prefix and a right branch appends 1. Nodes in
    # parentheses are filtered out.
    #             64+1
    #          /       \
    #      32+1          32+1
    #     /   \         /   \
    # (16+1) (16+1)  (16+1) (16+1)
    root_point_count = 64
    sh = get_test_sim_hash()
    cp = test_utils.get_test_coreset_param(
        min_num_points_in_node=20,
        min_num_points_in_branching_node=30,
        max_depth=5)
    root_points = get_test_origin_points(root_point_count)
    test_root = TestLshTreeNode('', root_points, cp, sh, frac_zero=0.5)

    def expected_node(prefix, num_points):
      """Builds an expected node sharing this test's cp and sh."""
      return TestLshTreeNode(
          prefix, get_test_origin_points(num_points), cp, sh)

    expected_leaves = [expected_node('0', 32), expected_node('1', 32)]
    tree = lsh_tree.LshTree(test_root)
    self.assertEqual(tree.leaves, expected_leaves)


# Run the tests in this module through absltest when executed as a script.
if __name__ == '__main__':
  absltest.main()
