| # Copyright 2022, The TensorFlow Authors. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # https://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Library for calibration of differentially private mechanisms. |
| |
| Algorithms to optimize some quantity while remaining within a specified privacy |
| budget. |
| """ |
| |
| from typing import Callable, Optional, Union |
| |
| import attr |
| from scipy import optimize |
| |
| from dp_accounting import dp_event |
| from dp_accounting import privacy_accountant |
| |
| |
class BracketInterval:
    """Common base type for bracket specifications.

    `calibrate_dp_mechanism` accepts any subclass it recognizes as its
    `bracket_interval` argument; the base class itself carries no data.
    """
| |
| |
@attr.frozen
class ExplicitBracketInterval(BracketInterval):
    """Immutable bracket given directly by its two endpoints.

    `calibrate_dp_mechanism` hands the two endpoints to
    `scipy.optimize.brentq` as the ends of the search interval.
    """

    # First end of the search interval.
    endpoint_1: float
    # Second end of the search interval.
    endpoint_2: float
| |
| |
@attr.frozen
class LowerEndpointAndGuess(BracketInterval):
    """Immutable bracket given by a lower endpoint and a first upper guess.

    The bracket search starts from [lower_endpoint, initial_guess] and
    expands upward until the interval encloses a solution; it requires
    `lower_endpoint < initial_guess`.
    """

    # Lower end of the starting interval.
    lower_endpoint: float
    # Upper end of the starting interval; must exceed lower_endpoint.
    initial_guess: float
| |
| |
class NoBracketIntervalFoundError(Exception):
    """Signals that no explicit bracket interval could be found.

    Raised by the bracket search once its expansion attempts are used up
    and the epsilon gap still has the same sign at both interval ends.
    """
| |
| |
class NoOptimumFoundError(Exception):
    """Signals that the root-finding step did not produce a usable value.

    Raised when `scipy.optimize.brentq` reports non-convergence, or when no
    parameter within the tolerance of its root satisfies the privacy target.
    """
| |
| |
class NonEmptyAccountantError(Exception):
    """Signals that make_fresh_accountant returned a used accountant.

    Raised when the `ledger` of a freshly made accountant is not a
    `NoOpDpEvent`.
    """
| |
| |
def _search_for_explicit_bracket_interval(
    bracket_interval: LowerEndpointAndGuess,
    epsilon_gap: Callable[[float], float]) -> ExplicitBracketInterval:
    """Grows an interval upward until it brackets a sign change.

    Starting from [lower_endpoint, initial_guess], the interval is repeatedly
    slid to the right: the old upper end becomes the new lower end and the
    width doubles. The search stops as soon as `epsilon_gap` no longer has
    strictly the same sign at both ends.

    Args:
      bracket_interval: The LowerEndpointAndGuess giving the starting
        interval.
      epsilon_gap: Function returning the epsilon at the given parameter
        value minus the target epsilon. It is assumed to be monotonic in its
        argument; otherwise the search may fail.

    Returns:
      An ExplicitBracketInterval over which `epsilon_gap` does not have
      strictly the same sign at both ends.

    Raises:
      ValueError: if the lower endpoint is not below the initial guess.
      NoBracketIntervalFoundError: if 30 expansions are not enough to find a
        bracketing interval.
    """
    left, right = attr.astuple(bracket_interval)
    if left >= right:
        raise ValueError(
            f'bracket_interval.lower_endpoint ({bracket_interval.lower_endpoint}) '
            f'must be less than bracket_interval.initial_guess '
            f'({bracket_interval.initial_guess}).')

    left_gap = epsilon_gap(left)
    right_gap = epsilon_gap(right)

    width = right - left
    expansions_left = 30  # Budget of interval expansions before giving up.

    # Same strict sign at both ends means no root is enclosed yet.
    while left_gap * right_gap > 0:
        if expansions_left == 0:
            raise NoBracketIntervalFoundError(
                'Unable to find bracketing interval within 2**30 of initial guess. '
                'Consider providing an ExplicitBracketInterval.')
        expansions_left -= 1

        # Slide right and double the width; only the new upper end needs a
        # fresh evaluation, the old upper value is reused for the lower end.
        width = 2 * width
        left, right = right, right + width
        left_gap, right_gap = right_gap, epsilon_gap(right)

    return ExplicitBracketInterval(left, right)
| |
| |
def calibrate_dp_mechanism(
    make_fresh_accountant: Callable[[], privacy_accountant.PrivacyAccountant],
    make_event_from_param: Union[Callable[[float], dp_event.DpEvent],
                                 Callable[[int], dp_event.DpEvent]],
    target_epsilon: float,
    target_delta: float,
    bracket_interval: Optional[BracketInterval] = None,
    discrete: bool = False,
    tol: Optional[float] = None) -> Union[float, int]:
    """Searches for optimal mechanism parameter value within privacy budget.

    The procedure searches over the space of parameters by creating, for each
    sample value, a DpEvent representing the mechanism generated from that
    value, and a freshly initialized PrivacyAccountant. Then the accountant is
    applied to the event to determine its epsilon at the target delta. Brent's
    method is used to determine the value of the parameter at which the target
    epsilon is achieved.

    Args:
      make_fresh_accountant: A callable with no parameters that returns an
        initialized PrivacyAccountant. The accountants that are returned
        across multiple calls are assumed to be initialized identically. It
        is an error for the initialized accountant's `ledger` property to
        return anything besides `NoOpDpEvent`.
      make_event_from_param: A callable that takes a parameter value as an
        argument and creates a `DpEvent` representing the mechanism defined
        using that value.
      target_epsilon: The target epsilon value. Must be nonnegative.
      target_delta: The target delta value. Must be in [0, 1].
      bracket_interval: A BracketInterval used to determine the upper and
        lower endpoints of the interval within which Brent's method will
        search. If None, searches for a non-negative bracket starting from
        [0, 1].
      discrete: A bool determining whether the parameter is continuous or
        discrete valued. If True, the parameter is assumed to take only
        integer values. Concretely, `discrete=True` has three effects.
        1) ints, not floats are passed to `make_event_from_param`. 2) The
        minimum optimization tolerance is 0.5. 3) An integer is returned.
      tol: The tolerance, in parameter space. If the maximum (or minimum)
        value of the parameter that meets the privacy requirements is x*,
        calibrate_dp_mechanism is guaranteed to return a value x such that
        |x - x*| <= tol. If `None`, tol is set to 1e-6 for continuous
        parameters or 0.5 for discrete parameters.

    Returns:
      A value of the parameter within tol of the optimum subject to the
      privacy constraint. If discrete=True, the returned value will be an
      integer. Otherwise it will be a float.

    Raises:
      TypeError: if make_fresh_accountant or make_event_from_param is not
        callable, or if bracket_interval is not a recognized BracketInterval
        type.
      ValueError: if target_epsilon is negative, target_delta is outside
        [0, 1], tol is not positive for a continuous parameter, or
        scipy.optimize.brentq raises ValueError (which often means the
        bracket interval did not bracket a solution).
      NoBracketIntervalFoundError: if bracket_interval is
        LowerEndpointAndGuess and no upper bound can be found within a factor
        of 2**30 of the original guess.
      NoOptimumFoundError: if scipy.optimize.brentq fails to find an optimum,
        or no value within tol of its root meets the privacy target.
      NonEmptyAccountantError: if make_fresh_accountant returns an accountant
        with nonempty ledger.
    """

    if not callable(make_fresh_accountant):
        raise TypeError(f'make_fresh_accountant must be callable. '
                        f'found {type(make_fresh_accountant)}.')

    if not callable(make_event_from_param):
        # Name and report the argument that actually failed the check.
        raise TypeError(f'make_event_from_param must be callable. '
                        f'found {type(make_event_from_param)}.')

    if target_epsilon < 0:
        raise ValueError(f'target_epsilon must be nonnegative. Found '
                         f'{target_epsilon}.')

    if not 0 <= target_delta <= 1:
        raise ValueError(f'target_delta must be in range [0, 1]. Found '
                         f'{target_delta}.')

    if bracket_interval is None:
        bracket_interval = LowerEndpointAndGuess(0, 1)

    if tol is None:
        tol = 0.5 if discrete else 1e-6
    elif discrete:
        # Integer parameters cannot be resolved more finely than 0.5.
        tol = max(tol, 0.5)
    elif tol <= 0:
        raise ValueError(f'tol must be positive. Found {tol}.')

    def epsilon_gap(x: float) -> float:
        """Returns epsilon at parameter x minus the target epsilon."""
        if discrete:
            x = round(x)
        event = make_event_from_param(x)
        accountant = make_fresh_accountant()
        if not isinstance(accountant.ledger, dp_event.NoOpDpEvent):
            raise NonEmptyAccountantError()
        epsilon = accountant.compose(event).get_epsilon(target_delta)
        return epsilon - target_epsilon

    if isinstance(bracket_interval, LowerEndpointAndGuess):
        bracket_interval = _search_for_explicit_bracket_interval(
            bracket_interval, epsilon_gap)
    elif not isinstance(bracket_interval, ExplicitBracketInterval):
        raise TypeError(f'Unrecognized bracket_interval type: '
                        f'{type(bracket_interval)}')

    try:
        root, result = optimize.brentq(
            epsilon_gap,
            bracket_interval.endpoint_1,
            bracket_interval.endpoint_2,
            xtol=tol,
            full_output=True)
    except ValueError as err:
        raise ValueError(
            '`brentq` raised ValueError. This often means the supplied bracket '
            f'interval {bracket_interval} did not bracket a solution.') from err

    if not result.converged:
        raise NoOptimumFoundError(
            'Unable to find root with scipy.optimize.brentq.')

    if epsilon_gap(root) > 0:
        # Ensure that gap is not positive, guaranteeing returned parameter
        # gives no less privacy than was requested. Which side of the root is
        # the private one is not known here, so both neighbours are tried.
        if epsilon_gap(root + tol) < 0:
            root += tol
        elif epsilon_gap(root - tol) < 0:
            root -= tol
        else:
            raise NoOptimumFoundError(
                f'Unable to find valid value near root {root} returned by brentq.')

    if discrete:
        root = round(root)

    return root