blob: 5975d3e0a7ed993ef896893d339309d2f9a5fc27 [file] [log] [blame]
# Copyright 2022, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for mechanism_calibration."""
from absl.testing import absltest
from absl.testing import parameterized
import attr
import numpy as np
from dp_accounting import dp_event
from dp_accounting import mechanism_calibration
from dp_accounting import privacy_accountant
@attr.define
class MockEvent(dp_event.DpEvent):
  """Minimal `DpEvent` that carries a single scalar parameter.

  The class itself is handed to `calibrate_dp_mechanism` as the callable that
  builds an event from a candidate parameter value, so `MockEvent(x)` is the
  event for `x`. `MockAccountant` reads `param` back when composing.
  """

  # Candidate parameter value chosen by the calibrator.
  param: float = attr.field()
class MockAccountant(privacy_accountant.PrivacyAccountant):
  """Fake accountant whose epsilon is a caller-supplied function of one value.

  The accountant stores only the `param` of the most recently composed event.
  Later compositions overwrite earlier ones, and `count` is ignored. It reports
  `value_to_epsilon(param)` as epsilon regardless of the requested delta.
  """

  def __init__(self, value_to_epsilon):
    """Initializes the accountant.

    Args:
      value_to_epsilon: Callable mapping the stored event parameter to the
        epsilon returned by `get_epsilon`.
    """
    super().__init__(
        privacy_accountant.NeighboringRelation.ADD_OR_REMOVE_ONE)
    # Parameter of the last composed event; 0.0 before any composition.
    self._value = 0.0
    self._value_to_epsilon = value_to_epsilon

  def _maybe_compose(self, event: dp_event.DpEvent, count: int,
                     do_compose: bool):
    """Records `event.param` as the current value when `do_compose` is True.

    Args:
      event: Event to compose. Assumed to expose a `param` attribute (as
        `MockEvent` does); other events raise AttributeError when composed.
      count: Ignored; the mock does not model repeated composition.
      do_compose: Whether to actually update state. A call with False must
        leave the accountant unchanged, so `_value` is only written when True.

    Returns:
      None. NOTE(review): callers compose MockEvents successfully, so the base
      class presumably treats a None result as "event supported" -- confirm
      against `PrivacyAccountant._maybe_compose`.
    """
    # Previously `_value` was overwritten even when do_compose was False, so a
    # non-composing call (e.g. a support check) would mutate the accountant.
    if do_compose:
      self._value = event.param
    return None

  def get_epsilon(self, target_delta: float) -> float:
    """Returns `value_to_epsilon` applied to the stored parameter.

    `target_delta` is ignored; the mock's epsilon does not depend on it.
    """
    return self._value_to_epsilon(self._value)
class MechanismCalibrationTest(parameterized.TestCase):
  """Tests for `calibrate_dp_mechanism` and its bracket-search helper.

  Most cases use `MockAccountant`, whose epsilon is an explicit function of the
  calibrated parameter, so the expected result can be worked out by hand.
  """

  @staticmethod
  def _identity_accountant():
    """Builds a fresh `MockAccountant` whose epsilon equals the parameter."""
    return MockAccountant(lambda x: x)

  @staticmethod
  def _epsilon_after_composing(eps_fn, param):
    """Composes `MockEvent(param)` into a new accountant; returns its epsilon."""
    acct = MockAccountant(eps_fn)
    acct.compose(MockEvent(param))
    return acct.get_epsilon(0)

  @parameterized.parameters(
      {'eps_fn': lambda x: x, 'expected': 2.0},
      {'eps_fn': lambda x: 4 - x, 'expected': 2.0},
      {'eps_fn': np.square, 'expected': np.sqrt(2)},
      {'eps_fn': np.cbrt, 'expected': 8.0},
      {'eps_fn': lambda x: (x - 5) ** 3 + 2, 'expected': 5},
      {'eps_fn': lambda x: np.cos(x / 3) + 2, 'expected': 3 * np.pi / 2},
      {'eps_fn': lambda x: np.sin(x - 5) + (x + 3) / 4, 'expected': 5},
      {'eps_fn': lambda x: (13 - x) / 4 - np.sin(x - 5), 'expected': 5},
  )
  def test_basic_inversion(self, eps_fn, expected):
    # Continuous case: each eps_fn equals 2 at `expected`, inside [0, 10].
    calibrated = mechanism_calibration.calibrate_dp_mechanism(
        lambda: MockAccountant(eps_fn), MockEvent, 2, 0,
        mechanism_calibration.ExplicitBracketInterval(0, 10), tol=1e-12)
    self.assertIsInstance(calibrated, float)
    self.assertAlmostEqual(calibrated, expected)
    # The returned parameter must actually satisfy the epsilon target.
    self.assertLessEqual(self._epsilon_after_composing(eps_fn, calibrated), 2)

  @parameterized.parameters(
      {'eps_fn': lambda x: -1 if x < 0 else 1},
      {'eps_fn': lambda x: 1 if x < 0 else -1},
      {'eps_fn': lambda x: x - 1 if x < 0 else x + 1},
      {'eps_fn': lambda x: -2 - x if x < 0 else 2 - x},
      {'eps_fn': lambda x: x + 2 if x < 0 else x - 2},
      {'eps_fn': lambda x: 1 - x if x < 0 else -1 - x},
  )
  def test_discontinuous(self, eps_fn):
    # Each eps_fn jumps across 0 at x == 0, so no exact root exists; the
    # calibrator should land next to the jump on the side meeting the target.
    calibrated = mechanism_calibration.calibrate_dp_mechanism(
        lambda: MockAccountant(eps_fn), MockEvent, 0, 0,
        mechanism_calibration.ExplicitBracketInterval(-1, 1), tol=1e-12)
    self.assertIsInstance(calibrated, float)
    self.assertAlmostEqual(calibrated, 0)
    self.assertLessEqual(self._epsilon_after_composing(eps_fn, calibrated), 0)

  @parameterized.parameters(
      {'eps_fn': lambda x: x - 2, 'expected_eps': 0},
      {'eps_fn': lambda x: x - 2.1, 'expected_eps': -0.1},
      {'eps_fn': lambda x: x - 2.9, 'expected_eps': -0.9},
      {'eps_fn': lambda x: 2 - x, 'expected_eps': 0},
      {'eps_fn': lambda x: 1.9 - x, 'expected_eps': -0.1},
      {'eps_fn': lambda x: 1.1 - x, 'expected_eps': -0.9},
  )
  def test_discrete(self, eps_fn, expected_eps):
    # Integer case: every eps_fn meets epsilon <= 0 at 2 but exceeds 0 at the
    # neighbouring integer on the other side, so 2 is the answer.
    calibrated = mechanism_calibration.calibrate_dp_mechanism(
        lambda: MockAccountant(eps_fn), MockEvent, 0, 0,
        mechanism_calibration.ExplicitBracketInterval(0, 5), discrete=True)
    self.assertIsInstance(calibrated, int)
    self.assertEqual(calibrated, 2)
    self.assertAlmostEqual(
        self._epsilon_after_composing(eps_fn, calibrated), expected_eps)

  @parameterized.parameters(
      {'epsilon_gap': lambda x: x, 'lower': -1, 'guess': -0.5},
      {'epsilon_gap': lambda x: -x, 'lower': -1, 'guess': -0.5},
      {'epsilon_gap': lambda x: np.exp(x) - 2, 'lower': 0, 'guess': 0.1},
      {'epsilon_gap': lambda x: 1 - np.sqrt(x), 'lower': 0, 'guess': 0.1},
      {'epsilon_gap': lambda x: np.log(x) - 20, 'lower': 1, 'guess': 2},
  )
  def test_search_for_explicit_bracket_interval(
      self, epsilon_gap, lower, guess):
    gap_at_lower = epsilon_gap(lower)
    interval = mechanism_calibration._search_for_explicit_bracket_interval(
        mechanism_calibration.LowerEndpointAndGuess(lower, guess), epsilon_gap)
    gap_at_upper = epsilon_gap(interval.endpoint_2)
    # The gap at the original lower endpoint and at the returned upper
    # endpoint must not share a strict sign, i.e. a root is bracketed.
    self.assertLessEqual(gap_at_lower * gap_at_upper, 0)

  def test_raises_unknown_bracket_interval_type(self):

    class UnknownBracketInterval(mechanism_calibration.BracketInterval):
      """A BracketInterval subclass the calibrator does not recognize."""

    with self.assertRaisesRegex(TypeError, 'Unrecognized bracket_interval'):
      mechanism_calibration.calibrate_dp_mechanism(
          self._identity_accountant, MockEvent, 1.0, 0,
          UnknownBracketInterval())

  def test_raises_mfa_not_callable(self):
    with self.assertRaisesRegex(TypeError, 'callable'):
      mechanism_calibration.calibrate_dp_mechanism(
          'not a callable', MockEvent, 1.0, 0,
          mechanism_calibration.ExplicitBracketInterval(0, 5))

  def test_raises_mefv_not_callable(self):
    with self.assertRaisesRegex(TypeError, 'callable'):
      mechanism_calibration.calibrate_dp_mechanism(
          self._identity_accountant, 'not a callable', 1.0, 0,
          mechanism_calibration.ExplicitBracketInterval(0, 5))

  def test_raises_target_epsilon_negative(self):
    with self.assertRaisesRegex(ValueError, 'nonnegative'):
      mechanism_calibration.calibrate_dp_mechanism(
          self._identity_accountant, MockEvent, -1.0, 0,
          mechanism_calibration.ExplicitBracketInterval(0, 5))

  def test_raises_target_delta_out_of_range(self):
    # One delta below the valid range and one above it.
    for bad_delta in (-0.1, 1.1):
      with self.assertRaisesRegex(ValueError, 'in range'):
        mechanism_calibration.calibrate_dp_mechanism(
            self._identity_accountant, MockEvent, 0.0, bad_delta,
            mechanism_calibration.ExplicitBracketInterval(0, 5))

  def test_bad_bracket_interval(self):
    # With epsilon == x and target 1, both endpoints of each interval lie on
    # the same side of the target, so neither brackets a root.
    for endpoints in ((2, 5), (-2, 0)):
      with self.assertRaisesRegex(ValueError, 'Bracket endpoints'):
        mechanism_calibration.calibrate_dp_mechanism(
            self._identity_accountant, MockEvent, 1.0, 0.0,
            mechanism_calibration.ExplicitBracketInterval(*endpoints))
    # The lower endpoint must be less than the initial guess.
    with self.assertRaisesRegex(ValueError, 'must be less than'):
      mechanism_calibration.calibrate_dp_mechanism(
          self._identity_accountant, MockEvent, 1.0, 0.0,
          mechanism_calibration.LowerEndpointAndGuess(2, 0))

  def test_negative_tol(self):
    with self.assertRaisesRegex(ValueError, 'tol'):
      mechanism_calibration.calibrate_dp_mechanism(
          self._identity_accountant, MockEvent, 1.0, 0.0,
          mechanism_calibration.LowerEndpointAndGuess(0, 1), tol=-1)

  def test_no_bracket_interval_found(self):
    # The search starting from (0, 1) is expected to give up before epsilon
    # == x reaches the huge target.
    with self.assertRaises(mechanism_calibration.NoBracketIntervalFoundError):
      mechanism_calibration.calibrate_dp_mechanism(
          self._identity_accountant, MockEvent, 1.0e10, 0.0,
          mechanism_calibration.LowerEndpointAndGuess(0, 1))

  def test_nonempty_accountant(self):

    def make_nonempty_accountant():
      # Deliberately hands back an accountant that has already composed an
      # event; the calibrator must reject it.
      acct = MockAccountant(lambda x: x)
      acct.compose(MockEvent(1.0))
      return acct

    with self.assertRaises(mechanism_calibration.NonEmptyAccountantError):
      mechanism_calibration.calibrate_dp_mechanism(
          make_nonempty_accountant, MockEvent, 0.5, 0.0,
          mechanism_calibration.ExplicitBracketInterval(0, 1))
if __name__ == '__main__':
  # Run every test case in this module through absl's test runner.
  absltest.main()