| # Copyright 2020 Google LLC. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Common classes and functions for the accounting library.""" |
| |
| import dataclasses |
| import math |
| from typing import Callable, List, Mapping, Optional, Tuple, Union |
| |
| import numpy as np |
| from scipy import fft |
| from scipy import signal |
| from scipy import special |
| |
| ArrayLike = Union[np.ndarray, List[float]] |
| |
| |
@dataclasses.dataclass
class DifferentialPrivacyParameters:
  """Privacy guarantee of a mechanism, given as an (epsilon, delta) pair.

  The values are validated on construction: a ValueError is raised when
  epsilon is negative, or when delta lies outside of the interval [0, 1].

  Attributes:
    epsilon: the epsilon in (epsilon, delta)-differential privacy.
    delta: the delta in (epsilon, delta)-differential privacy. Defaults to 0.
  """
  epsilon: float
  delta: float = 0

  def __post_init__(self):
    """Rejects parameter values that are outside of their valid ranges."""
    epsilon_is_negative = self.epsilon < 0
    if epsilon_is_negative:
      raise ValueError(f'epsilon should be positive: {self.epsilon}')

    delta_out_of_range = self.delta > 1 or self.delta < 0
    if delta_out_of_range:
      raise ValueError(f'delta should be between 0 and 1: {self.delta}')
| |
| |
@dataclasses.dataclass
class BinarySearchParameters:
  """Configuration of a binary search over a bounded range.

  Attributes:
    lower_bound: A lower bound on the binary search range.
    upper_bound: An upper bound on the binary search range.
    initial_guess: An optional initial guess to start the search with. Must be
      positive when given. A guess that is close to the true value can make
      the binary search faster.
    tolerance: An acceptable error on the returned value.
    discrete: Whether the search is over integers.
  """
  lower_bound: float
  upper_bound: float
  initial_guess: Optional[float] = None
  tolerance: float = 1e-7
  discrete: bool = False
| |
| |
def inverse_monotone_function(func: Callable[[float], float],
                              value: float,
                              search_parameters: 'BinarySearchParameters',
                              increasing: bool = False) -> Optional[float]:
  """Inverts a monotone function using binary search.

  Args:
    func: The monotone function to be inverted.
    value: The desired value of the function.
    search_parameters: Parameters used for binary search.
    increasing: Whether the function is monotonically increasing. When False,
      the function is treated as monotonically decreasing.

  Returns:
    x such that func(x) is no more than value, when such x exists. The returned
    x is within the search tolerance (search_parameters.tolerance, or 1 when
    the search is discrete) of the smallest (for monotonically decreasing func)
    or the largest (for monotonically increasing func) such x. When no such x
    exists within the given range, returns None.
  """
  lower_x = search_parameters.lower_bound
  upper_x = search_parameters.upper_bound
  initial_guess_x = search_parameters.initial_guess

  # check(func(x), value) is True exactly when x belongs to the lower part of
  # the search interval, i.e., when the answer lies at or above x.
  if increasing:
    check = lambda func_value, target_value: func_value <= target_value
    # An increasing func is smallest at lower_x; if even that exceeds value,
    # no x in the range qualifies. An infinite bound is never evaluated.
    if lower_x != -math.inf and func(lower_x) > value:
      return None
  else:
    check = lambda func_value, target_value: func_value > target_value
    # A decreasing func is smallest at upper_x; same reasoning as above.
    if upper_x != math.inf and func(upper_x) > value:
      return None

  if initial_guess_x is not None:
    # Repeatedly double the guess while it still belongs to the lower part of
    # the interval; this narrows [lower_x, upper_x] before bisection starts.
    while initial_guess_x < upper_x and check(func(initial_guess_x), value):
      lower_x = initial_guess_x
      if initial_guess_x <= 0:
        # Doubling a non-positive guess never moves it towards upper_x, which
        # would make this loop run forever. Stop expanding and let the plain
        # bisection below search the remaining range.
        initial_guess_x = upper_x
      else:
        initial_guess_x *= 2
    upper_x = min(upper_x, initial_guess_x)

  # Over integers the search stops as soon as the bounds are adjacent.
  tolerance = 1 if search_parameters.discrete else search_parameters.tolerance

  while upper_x - lower_x > tolerance:
    if search_parameters.discrete:
      mid_x = (upper_x + lower_x) // 2
    else:
      mid_x = (upper_x + lower_x) / 2

    if check(func(mid_x), value):
      lower_x = mid_x
    else:
      upper_x = mid_x

  # For an increasing func the lower end satisfies func(x) <= value; for a
  # decreasing func it is the upper end that does.
  return lower_x if increasing else upper_x
| |
| |
def dictionary_to_list(
    input_dictionary: Mapping[int, float]) -> Tuple[int, List[float]]:
  """Converts an integer-keyed dictionary into a dense list.

  Args:
    input_dictionary: A dictionary whose keys are integers.

  Returns:
    A tuple of an integer offset and a list result_list. The offset is the
    smallest key of the input dictionary. result_list has one entry for every
    integer from the smallest key to the largest key, inclusive.
    result_list[i] is equal to input_dictionary[offset + i], and is zero if
    offset + i is not a key in the input dictionary.
  """
  smallest_key, largest_key = min(input_dictionary), max(input_dictionary)

  dense_values = []
  for key in range(smallest_key, largest_key + 1):
    # Keys that are absent from the dictionary are filled in with zero.
    dense_values.append(input_dictionary.get(key, 0))
  return smallest_key, dense_values
| |
| |
def list_to_dictionary(input_list: List[float],
                       offset: int,
                       tail_mass_truncation: float = 0) -> Mapping[int, float]:
  """Converts a list into an integer-keyed dictionary, with a specified offset.

  Entries at both ends of the list are dropped for as long as the cumulative
  mass scanned from that end does not exceed tail_mass_truncation / 2.

  Args:
    input_list: An input list.
    offset: The offset in the key of the output dictionary.
    tail_mass_truncation: an upper bound on the tails of the input list that
      might be truncated.

  Returns:
    A dictionary whose value at key is equal to input_list[key - offset], for
    the entries that survive the truncation. If input_list[key - offset] is
    less than or equal to zero, it is not included in the dictionary.
  """
  list_length = len(input_list)

  # First index, scanning from the left, at which the accumulated mass exceeds
  # half of the truncation budget; list_length if that never happens.
  first_kept_index = list_length
  accumulated_mass = 0
  for index, mass in enumerate(input_list):
    accumulated_mass += mass
    if accumulated_mass > tail_mass_truncation / 2:
      first_kept_index = index
      break

  # Same scan from the right; -1 if the budget is never exceeded.
  last_kept_index = -1
  accumulated_mass = 0
  for index in range(list_length - 1, -1, -1):
    accumulated_mass += input_list[index]
    if accumulated_mass > tail_mass_truncation / 2:
      last_kept_index = index
      break

  return {
      index + offset: input_list[index]
      for index in range(first_kept_index, last_kept_index + 1)
      if input_list[index] > 0
  }
| |
| |
def convolve_dictionary(dictionary1: Mapping[int, float],
                        dictionary2: Mapping[int, float],
                        tail_mass_truncation: float = 0) -> Mapping[int, float]:
  """Computes a convolution of two integer-keyed dictionaries.

  Args:
    dictionary1: The first dictionary whose keys are integers.
    dictionary2: The second dictionary whose keys are integers.
    tail_mass_truncation: an upper bound on the tails of the output that might
      be truncated.

  Returns:
    The dictionary where for each key its corresponding value is the sum, over
    all key1, key2 such that key1 + key2 = key, of dictionary1[key1] times
    dictionary2[key2].
  """
  # Densify both dictionaries so that they can be convolved as sequences.
  offset1, values1 = dictionary_to_list(dictionary1)
  offset2, values2 = dictionary_to_list(dictionary2)

  # Index 0 of the convolved sequence corresponds to key offset1 + offset2.
  convolved_values = signal.fftconvolve(values1, values2)

  return list_to_dictionary(
      convolved_values,
      offset1 + offset2,
      tail_mass_truncation=tail_mass_truncation)
| |
| |
def compute_self_convolve_bounds(
    input_list: List[float],
    num_times: int,
    tail_mass_truncation: float = 0,
    orders: Optional[List[float]] = None) -> Tuple[int, int]:
  """Computes truncation bounds for convolution using Chernoff bound.

  Args:
    input_list: The input list to be convolved.
    num_times: The number of times the list is to be convolved with itself.
    tail_mass_truncation: an upper bound on the tails of the output that might
      be truncated.
    orders: a list of orders on which the Chernoff bound is applied. When not
      given, the non-zero integers from -20 to 20 divided by the length of
      input_list are used.

  Returns:
    A pair (lower_bound, upper_bound) for which the mass of the result of
    convolution outside of this range is at most tail_mass_truncation.
  """
  # Without truncation, the full support of the convolution is returned.
  lower_bound, upper_bound = 0, (len(input_list) - 1) * num_times
  if tail_mass_truncation == 0:
    return lower_bound, upper_bound

  if orders is None:
    # Keep the absolute values of the orders small; otherwise, we may run into
    # numerical issues.
    nonzero_integers = np.concatenate((np.arange(-20, 0), np.arange(1, 21)))
    orders = nonzero_integers / len(input_list)

  indices = np.arange(len(input_list))
  for order in orders:
    # Log of the moment generating function at this order.
    log_mgf_value = special.logsumexp(indices * order, b=input_list)

    # Use Chernoff bound to update the upper/lower bound. See equation (5) in
    # the supplementary material.
    bound = (num_times * log_mgf_value +
             math.log(2 / tail_mass_truncation)) / order
    if order > 0:
      upper_bound = min(upper_bound, math.ceil(bound))
    elif order < 0:
      lower_bound = max(lower_bound, math.floor(bound))

  return lower_bound, upper_bound
| |
| |
def self_convolve(input_list: ArrayLike,
                  num_times: int,
                  tail_mass_truncation: float = 0) -> Tuple[int, List[float]]:
  """Computes a convolution of the input list with itself num_times times.

  Args:
    input_list: The input list to be convolved.
    num_times: The number of times the list is to be convolved with itself.
    tail_mass_truncation: an upper bound on the tails of the output that might
      be truncated.

  Returns:
    A pair of truncation_lower_bound, output_list, where the i-th entry of
    output_list is approximately the sum, over all i_1, i_2, ..., i_num_times
    such that i_1 + i_2 + ... + i_num_times = i + truncation_lower_bound,
    of input_list[i_1] * input_list[i_2] * ... * input_list[i_num_times].
  """
  window_start, window_end = compute_self_convolve_bounds(
      input_list, num_times, tail_mass_truncation)
  window_len = window_end - window_start + 1

  # The transform must hold both the window of interest and the whole input.
  transform_len = fft.next_fast_len(max(window_len, len(input_list)))

  # Convolving num_times times is a pointwise power in the Fourier domain.
  spectrum = fft.fft(input_list, transform_len)
  wrapped_convolution = np.real(fft.ifft(spectrum**num_times))

  # Discrete Fourier Transform wraps around modulo transform_len. Shift the
  # window of interest to the front and cut it out.
  shifted_convolution = np.roll(wrapped_convolution, -window_start)
  return window_start, shifted_convolution[:window_len]
| |
| |
def self_convolve_dictionary(
    input_dictionary: Mapping[int, float],
    num_times: int,
    tail_mass_truncation: float = 0) -> Mapping[int, float]:
  """Computes a convolution of the input dictionary with itself num_times times.

  Args:
    input_dictionary: The input dictionary whose keys are integers.
    num_times: The number of times the dictionary is to be convolved with
      itself.
    tail_mass_truncation: an upper bound on the tails of the output that might
      be truncated.

  Returns:
    The dictionary where for each key its corresponding value is the sum, over
    all key1, key2, ..., key_num_times such that key1 + key2 + ... +
    key_num_times = key, of input_dictionary[key1] * input_dictionary[key2] *
    ... * input_dictionary[key_num_times].
  """
  smallest_key, dense_values = dictionary_to_list(input_dictionary)
  truncation_shift, convolved_values = self_convolve(
      dense_values, num_times, tail_mass_truncation=tail_mass_truncation)

  # Index 0 of convolved_values corresponds to num_times copies of the
  # smallest key, shifted by the lower truncation bound of the convolution.
  output_offset = smallest_key * num_times + truncation_shift
  return list_to_dictionary(convolved_values, output_offset)
| |
| |
| def _log_add(a: float, b: float) -> float: |
| """Returns log(exp(a) + exp(b)).""" |
| mn, mx = min(a, b), max(a, b) |
| return mx + np.log1p(np.exp(mn - mx)) |
| |
| |
| def _log_sub(a: float, b: float) -> float: |
| """Returns log(exp(a) - exp(b)).""" |
| if b >= a: |
| raise ValueError(f'a must be greater than b. Got a={a} and b={b}.') |
| return a + np.log1p(-np.exp(b - a)) |
| |
| |
def log_a_times_exp_b_plus_c(a: float, b: float, c: float) -> float:
  """Computes log(a * exp(b) + c).

  The computation is carried out in log space where possible, so that exp(b)
  itself is never evaluated directly.

  Args:
    a: The coefficient of the exponential term.
    b: The exponent.
    c: The additive constant.

  Returns:
    The value of log(a * exp(b) + c).

  Raises:
    ValueError: If a is negative and c is not positive.
  """
  # The exponential term vanishes.
  if a == 0:
    return np.log(c)

  if a < 0 and c <= 0:
    raise ValueError(f'a exp(b) + c must be positive: {a}, {b}, {c}.')
  if a < 0:
    # c - (-a) * exp(b), with both terms expressed by their logarithms.
    return _log_sub(np.log(c), np.log(-a) + b)

  # exp(0) = 1, so the expression collapses to log(a + c).
  if b == 0:
    return np.log(a + c)

  # Logarithm of the (positive) term a * exp(b).
  log_a_exp_b = b + np.log(a)
  if c == 0:
    return log_a_exp_b
  if c < 0:
    return _log_sub(log_a_exp_b, np.log(-c))
  return _log_add(log_a_exp_b, np.log(c))