# blob: 055788c9a970ed533d6538d2e0268a968f0cd38e [file] [log] [blame]
# Copyright 2022, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library for calibration of differentially private mechanisms.

Provides algorithms that optimize some quantity while remaining within a
specified privacy budget.
"""
from typing import Callable, Optional, Union
import attr
from scipy import optimize
from dp_accounting import dp_event
from dp_accounting import privacy_accountant
class BracketInterval:
  """Marker base class for the ways a search interval can be specified.

  It carries no state or behavior of its own; concrete subclasses hold the
  actual interval description.
  """
@attr.frozen
class ExplicitBracketInterval(BracketInterval):
  """Immutable bracket described directly by its two endpoints."""

  # First end of the interval.
  endpoint_1: float
  # Second end of the interval.
  endpoint_2: float
@attr.frozen
class LowerEndpointAndGuess(BracketInterval):
  """Immutable bracket described by a lower end and a first upper guess."""

  # Fixed lower end of the initial interval.
  lower_endpoint: float
  # First candidate for the upper end of the interval.
  initial_guess: float
class NoBracketIntervalFoundError(Exception):
  """Signals that no explicit bracketing interval could be located.

  Raised by the bracket search once it has exhausted its allowed number of
  interval expansions without finding a sign change.
  """
class NoOptimumFoundError(Exception):
  """Signals that the root-finding step did not yield a usable value.

  Raised when the root finder reports non-convergence, or when no value near
  the reported root satisfies the privacy constraint.
  """
class NonEmptyAccountantError(Exception):
  """Signals that a supposedly fresh accountant already has a nonempty ledger.

  Raised when the accountant produced by `make_fresh_accountant` has a ledger
  that is not a `NoOpDpEvent`.
  """
def _search_for_explicit_bracket_interval(
    bracket_interval: LowerEndpointAndGuess,
    epsilon_gap: Callable[[float], float]) -> ExplicitBracketInterval:
  """Slides and widens an interval upward until it brackets a sign change.

  Starting from [lower_endpoint, initial_guess], each step moves the window so
  that it begins at the previous upper end and is twice as wide as before. The
  search stops as soon as `epsilon_gap` no longer has the same strict sign at
  both ends of the window.

  Args:
    bracket_interval: A LowerEndpointAndGuess giving the starting window. Its
      lower endpoint must be strictly less than its initial guess.
    epsilon_gap: Function computing the epsilon at the provided value minus
      the target epsilon. It is assumed to be monotonic in its argument;
      otherwise the search may fail.

  Returns:
    An ExplicitBracketInterval whose endpoint gaps do not share a strict sign.

  Raises:
    ValueError: if the lower endpoint is not less than the initial guess.
    NoBracketIntervalFoundError: if no valid bracketing interval is found
      within a factor of 2**30 of the initial guess.
  """
  low, high = attr.astuple(bracket_interval)
  if high <= low:
    raise ValueError(
        f'bracket_interval.lower_endpoint ({bracket_interval.lower_endpoint}) '
        f'must be less than bracket_interval.initial_guess '
        f'({bracket_interval.initial_guess}).')

  max_doublings = 30
  low_gap = epsilon_gap(low)
  high_gap = epsilon_gap(high)
  width = high - low

  for _ in range(max_doublings):
    # Written as `not (... > 0)` so a NaN product also ends the search, exactly
    # as a failed `while ... > 0` test would.
    if not low_gap * high_gap > 0:
      return ExplicitBracketInterval(low, high)
    # After k passes the width equals the initial width times 2 ** k.
    width *= 2
    # The old upper end becomes the new lower end; reuse its known gap value.
    low, low_gap = high, high_gap
    high = low + width
    high_gap = epsilon_gap(high)

  # All doublings used up: accept the final window only if it brackets.
  if low_gap * high_gap > 0:
    raise NoBracketIntervalFoundError(
        'Unable to find bracketing interval within 2**30 of initial guess. '
        'Consider providing an ExplicitBracketInterval.')
  return ExplicitBracketInterval(low, high)
def calibrate_dp_mechanism(
    make_fresh_accountant: Callable[[], privacy_accountant.PrivacyAccountant],
    make_event_from_param: Union[Callable[[float], dp_event.DpEvent],
                                 Callable[[int], dp_event.DpEvent]],
    target_epsilon: float,
    target_delta: float,
    bracket_interval: Optional[BracketInterval] = None,
    discrete: bool = False,
    tol: Optional[float] = None) -> Union[float, int]:
  """Searches for optimal mechanism parameter value within privacy budget.

  The procedure searches over the space of parameters by creating, for each
  sample value, a DpEvent representing the mechanism generated from that value,
  and a freshly initialized PrivacyAccountant. Then the accountant is applied to
  the event to determine its epsilon at the target delta. Brent's method is used
  to determine the value of the parameter at which the target epsilon is
  achieved.

  Args:
    make_fresh_accountant: A callable with no parameters that returns an
      initialized PrivacyAccountant. The accountants that are returned across
      multiple calls are assumed to be initialized identically. It is an error
      for the initialized accountant's `ledger` property to return anything
      besides `NoOpDpEvent`.
    make_event_from_param: A callable that takes a parameter value as an
      argument and creates a `DpEvent` representing the mechanism defined using
      that value.
    target_epsilon: The target epsilon value. Must be nonnegative.
    target_delta: The target delta value. Must lie in [0, 1].
    bracket_interval: A BracketInterval used to determine the upper and lower
      endpoints of the interval within which Brent's method will search. If
      None, searches for a non-negative bracket starting from [0, 1].
    discrete: A bool determining whether the parameter is continuous or discrete
      valued. If True, the parameter is assumed to take only integer values.
      Concretely, `discrete=True` has three effects. 1) ints, not floats are
      passed to `make_event_from_param`. 2) The minimum optimization tolerance
      is 0.5. 3) An integer is returned.
    tol: The tolerance, in parameter space. If the maximum (or minimum) value of
      the parameter that meets the privacy requirements is x*,
      calibrate_dp_mechanism is guaranteed to return a value x such that |x -
      x*| <= tol. If `None`, tol is set to 1e-6 for continuous parameters or 0.5
      for discrete parameters.

  Returns:
    A value of the parameter within tol of the optimum subject to the privacy
    constraint. If discrete=True, the returned value will be an integer.
    Otherwise it will be a float.

  Raises:
    TypeError: if make_fresh_accountant or make_event_from_param is not
      callable, or bracket_interval is of an unrecognized type.
    ValueError: if target_epsilon is negative, target_delta is outside [0, 1],
      tol is not positive for a continuous parameter, or brentq rejects the
      bracket interval.
    NoBracketIntervalFoundError: if bracket_interval is LowerEndpointAndGuess
      and no upper bound can be found within a factor of 2**30 of the original
      guess.
    NoOptimumFoundError: if scipy.optimize.brentq fails to find an optimum, or
      no value within tol of the root satisfies the privacy constraint.
    NonEmptyAccountantError: if make_fresh_accountant returns an accountant with
      nonempty ledger.
  """
  if not callable(make_fresh_accountant):
    raise TypeError(f'make_fresh_accountant must be callable. '
                    f'found {type(make_fresh_accountant)}.')

  if not callable(make_event_from_param):
    # Report the argument that actually failed the check.
    raise TypeError(f'make_event_from_param must be callable. '
                    f'found {type(make_event_from_param)}.')

  if target_epsilon < 0:
    raise ValueError(f'target_epsilon must be nonnegative. Found '
                     f'{target_epsilon}.')

  if not 0 <= target_delta <= 1:
    raise ValueError(f'target_delta must be in range [0, 1]. Found '
                     f'{target_delta}.')

  if bracket_interval is None:
    bracket_interval = LowerEndpointAndGuess(0, 1)

  if tol is None:
    tol = 0.5 if discrete else 1e-6
  elif discrete:
    # Integer parameters cannot be resolved more finely than half a unit.
    tol = max(tol, 0.5)
  elif tol <= 0:
    raise ValueError(f'tol must be positive. Found {tol}.')

  def epsilon_gap(x: float) -> float:
    """Returns epsilon of the mechanism at `x` minus the target epsilon."""
    if discrete:
      x = round(x)
    event = make_event_from_param(x)
    accountant = make_fresh_accountant()
    if not isinstance(accountant.ledger, dp_event.NoOpDpEvent):
      raise NonEmptyAccountantError(
          'make_fresh_accountant must return an accountant whose ledger is a '
          f'NoOpDpEvent. Found ledger of type {type(accountant.ledger)}.')
    return accountant.compose(event).get_epsilon(target_delta) - target_epsilon

  if isinstance(bracket_interval, LowerEndpointAndGuess):
    bracket_interval = _search_for_explicit_bracket_interval(
        bracket_interval, epsilon_gap)
  elif not isinstance(bracket_interval, ExplicitBracketInterval):
    raise TypeError(f'Unrecognized bracket_interval type: '
                    f'{type(bracket_interval)}')

  try:
    root, result = optimize.brentq(
        epsilon_gap,
        bracket_interval.endpoint_1,
        bracket_interval.endpoint_2,
        xtol=tol,
        full_output=True)
  except ValueError as err:
    raise ValueError(
        '`brentq` raised ValueError. This often means the supplied bracket '
        f'interval {bracket_interval} did not bracket a solution.') from err

  if not result.converged:
    raise NoOptimumFoundError(
        'Unable to find root with scipy.optimize.brentq.')

  if epsilon_gap(root) > 0:
    # Ensure that gap is not positive, guaranteeing returned parameter gives no
    # less privacy than was requested. A gap of exactly zero meets the target,
    # so it is accepted here just as it is for the root itself.
    if epsilon_gap(root + tol) <= 0:
      root += tol
    elif epsilon_gap(root - tol) <= 0:
      root -= tol
    else:
      raise NoOptimumFoundError(
          f'Unable to find valid value near root {root} returned by brentq.')

  if discrete:
    root = round(root)

  return root