blob: 1737fbc41933fc7c60f229240acae91692dc0bcf [file] [log] [blame]
# Copyright 2021 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for clustering_params."""
from absl.testing import absltest
import numpy as np
from clustering import clustering_params
class ClusteringParamTest(absltest.TestCase):
  """Tests for the parameter and data containers in `clustering_params`."""

  # Expected result of clipping `_make_points_to_clip()` to radius 10.0.
  # The first two points already lie within radius 10 and are unchanged; the
  # remaining three are rescaled onto the sphere of radius 10.
  _CLIPPED_TO_RADIUS_10 = (
      (0., 0., 0., 0.),
      (1., 2., 3., 4.),
      (3.79049022, 4.54858826, 5.30668631, 6.06478435),
      (4.26162351, 4.73513724, 5.20865096, 5.68216469),
      (4.46949207, 4.81329915, 5.15710623, 5.50091331),
  )

  @staticmethod
  def _make_points_to_clip():
    """Returns a fresh array of the five 4-dimensional points to clip."""
    # Built per call so no test can observe another test's mutations.
    return np.array([[0., 0., 0., 0.], [1., 2., 3., 4.], [5., 6., 7., 8.],
                     [9., 10., 11., 12.], [13., 14., 15., 16.]])

  def _assert_clipped_to_radius_10(self, clipped_datapoints):
    """Asserts `clipped_datapoints` matches `_CLIPPED_TO_RADIUS_10`.

    Args:
      clipped_datapoints: Result of `Data.clip_by_radius` applied to the
        points from `_make_points_to_clip()` with radius 10.0.
    """
    self.assertLen(clipped_datapoints, len(self._CLIPPED_TO_RADIUS_10))
    for index, expected in enumerate(self._CLIPPED_TO_RADIUS_10):
      # subTest reports every mismatching point instead of stopping at the
      # first one.
      with self.subTest(index=index):
        self.assertSequenceAlmostEqual(clipped_datapoints[index], expected)

  def test_privacy_param_defaults(self):
    """DifferentialPrivacyParam defaults to eps=1, delta=1e-6, central."""
    privacy_param = clustering_params.DifferentialPrivacyParam()
    self.assertEqual(privacy_param.epsilon, 1.0)
    self.assertEqual(privacy_param.delta, 1e-6)
    self.assertEqual(privacy_param.privacy_model,
                     clustering_params.PrivacyModel.CENTRAL)

  def test_privacy_budget_split_defaults(self):
    """PrivacyBudgetSplit defaults to an 0.8 / 0.2 split."""
    privacy_budget_split = clustering_params.PrivacyBudgetSplit()
    self.assertEqual(privacy_budget_split.frac_sum, 0.8)
    self.assertEqual(privacy_budget_split.frac_group_count, 0.2)

  def test_privacy_budget_split_invalid(self):
    """A budget split whose fractions sum above 1.0 is rejected."""
    # `msg` is only the failure text shown if no ValueError is raised; it does
    # not check the exception's message. 0.7 + 0.8 sums to 1.5.
    with self.assertRaises(
        ValueError,
        msg="A privacy budget split summing to 1.5 should be rejected."):
      clustering_params.PrivacyBudgetSplit(frac_sum=0.7, frac_group_count=0.8)

  def test_tree_param(self):
    """TreeParam stores its constructor arguments unchanged."""
    tree_param = clustering_params.TreeParam(
        min_num_points_in_branching_node=4,
        min_num_points_in_node=2,
        max_depth=5)
    self.assertEqual(tree_param.min_num_points_in_branching_node, 4)
    self.assertEqual(tree_param.min_num_points_in_node, 2)
    self.assertEqual(tree_param.max_depth, 5)

  def test_error_tree_param(self):
    """TreeParam rejects invalid values of min_num_points_in_node."""
    # 0 and -2 are non-positive; 20 is presumably invalid because it exceeds
    # min_num_points_in_branching_node=4 (TODO confirm against TreeParam).
    for min_num_points_in_node in (0, -2, 20):
      with self.subTest(min_num_points_in_node=min_num_points_in_node):
        with self.assertRaises(ValueError):
          clustering_params.TreeParam(
              min_num_points_in_branching_node=4,
              min_num_points_in_node=min_num_points_in_node,
              max_depth=5)

  def test_data(self):
    """Data reports its number of points, dimension and radius."""
    points, dim = 10, 3
    data = clustering_params.Data(np.ones(shape=(points, dim)), radius=1.0)
    self.assertEqual(data.num_points, points)
    self.assertEqual(data.dim, dim)
    self.assertEqual(data.radius, 1.0)

  def test_data_label_unequal_length(self):
    """Data rejects labels whose length differs from the number of points."""
    points, dim = 10, 3
    datapoints = np.zeros(shape=(points, dim))
    labels = np.ones(points - 1, dtype=int)
    with self.assertRaises(ValueError):
      clustering_params.Data(datapoints, radius=1.0, labels=labels)

  def test_clip_by_radius(self):
    """Caller-supplied points are clipped to the Data's radius."""
    # The Data's own (3-dimensional) points are not the ones being clipped.
    datapoints = np.array([[3., 2., 4.], [1., 2., 3.]])
    data = clustering_params.Data(datapoints, radius=10.0)
    clipped_datapoints = data.clip_by_radius(self._make_points_to_clip())
    self._assert_clipped_to_radius_10(clipped_datapoints)

  def test_clip_by_radius_default_to_self(self):
    """With no argument, clip_by_radius clips the Data's own datapoints."""
    data = clustering_params.Data(self._make_points_to_clip(), radius=10.0)
    clipped_datapoints = data.clip_by_radius()
    self._assert_clipped_to_radius_10(clipped_datapoints)

  def test_privacy_calculator_multiplier(self):
    """Gaussian/Laplace parameters scale with the configured multipliers."""
    multiplier = clustering_params.PrivacyCalculatorMultiplier(
        gaussian_std_dev_multiplier=4.2, laplace_param_multiplier=5.1)
    alpha = 3.0
    sensitivity = 1.4
    std_dev = multiplier.get_gaussian_std_dev(alpha, sensitivity)
    # 17.64 == 4.2 * 3.0 * 1.4 and 15.3 == 5.1 * 3.0 mathematically. Compare
    # approximately: the floating-point result depends on evaluation order.
    self.assertAlmostEqual(std_dev, 17.64)
    self.assertAlmostEqual(multiplier.get_alpha(std_dev, sensitivity), alpha)
    self.assertAlmostEqual(multiplier.get_laplace_param(alpha), 1.0 / 15.3)
if __name__ == "__main__":
  # Hand control to absltest's runner, which executes the test cases above.
  absltest.main()