blob: eb7666ea0c42a6ec7f518a273b8b30f140f54209 [file] [log] [blame]
# Copyright 2021 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for default_clustering_params."""
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from clustering import central_privacy_utils
from clustering import clustering_params
from clustering import default_clustering_params
from clustering import test_utils
class ClusteringParamTest(parameterized.TestCase):
  """Unit tests for the helpers in `default_clustering_params`."""

  @parameterized.named_parameters(
      ('basic', 100000, 90000, 10, 25.8, 567, 189, 20),
      ('zero_std_dev', 100000, 90000, 10, 0, 3, 1, 20),
      ('min_num_points_min', 1000, 900, 100, 25.8, 12, 4, 20),
      ('negative_private_count', 100000, -100, 10, 25.8, 3, 1, 20),
  )
  @mock.patch.object(central_privacy_utils, 'get_private_count', autospec=True)
  def test_default_tree_param(self, points, returned_private_count, k,
                              gaussian_std_dev,
                              expected_min_num_points_in_branching_node,
                              expected_min_num_points_in_node,
                              expected_max_depth, mock_private_count):
    """Checks the tree parameters and count produced by `default_tree_param`.

    `central_privacy_utils.get_private_count` is mocked so that its return
    value is fixed to `returned_private_count` for each parameterized case.

    Args:
      points: Number of data points (all identical, built with `np.ones`).
      returned_private_count: Value the mocked `get_private_count` returns.
      k: Number of clusters passed to `default_tree_param`.
      gaussian_std_dev: Std. dev. used to build the test privacy calculator.
      expected_min_num_points_in_branching_node: Expected value of
        `tree_param.min_num_points_in_branching_node`.
      expected_min_num_points_in_node: Expected value of
        `tree_param.min_num_points_in_node`.
      expected_max_depth: Max depth passed into `default_tree_param`; the
        same value is expected back on `tree_param.max_depth`.
      mock_private_count: Autospec mock of `get_private_count`, appended as
        the last positional argument by `mock.patch.object`.
    """
    num_dims, point_radius = 10, 4.3
    mock_private_count.return_value = returned_private_count

    dataset = clustering_params.Data(
        np.ones(shape=(points, num_dims)), radius=point_radius)
    privacy_calc = test_utils.get_test_privacy_calculator(
        gaussian_std_dev=gaussian_std_dev, sensitivity=point_radius)

    tree_param, private_count = default_clustering_params.default_tree_param(
        k, dataset, privacy_calc, expected_max_depth)

    # The private count must be requested exactly once, on the true number
    # of points, using the calculator's count privacy parameter; its result
    # is handed straight back to the caller.
    mock_private_count.assert_called_once_with(
        nonprivate_count=points,
        count_privacy_param=privacy_calc.count_privacy_param)
    self.assertEqual(private_count, returned_private_count)

    # Tree parameters derived for this case.
    self.assertEqual(tree_param.max_depth, expected_max_depth)
    self.assertEqual(tree_param.min_num_points_in_node,
                     expected_min_num_points_in_node)
    self.assertEqual(tree_param.min_num_points_in_branching_node,
                     expected_min_num_points_in_branching_node)
if __name__ == '__main__':
  # Run the tests through absl's test runner when executed as a script.
  absltest.main()