blob: 9aeb72cc0c5f8ee988f2bb838b31ec9789d89388 [file] [log] [blame]
//
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#ifndef DIFFERENTIAL_PRIVACY_ALGORITHMS_UTIL_H_
#define DIFFERENTIAL_PRIVACY_ALGORITHMS_UTIL_H_
#include <algorithm>
#include <cmath>
#include <limits>
#include <numeric>
#include <optional>
#include <string>
#include <type_traits>
#include <vector>
#include <cstdint>
#include "base/logging.h"
#include "absl/base/attributes.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "base/status_macros.h"
namespace differential_privacy {
// Arbitrary default value for epsilon. The algorithm interface falls back on
// this value whenever one is not provided. This value should only be used for
// testing convenience. For any production use case, please set your own epsilon
// based on privacy needs.
ABSL_DEPRECATED("Use your own epsilon based on privacy considerations.")
double DefaultEpsilon();
// Returns the smallest power of 2 greater than or equal to n. n must be > 0.
// Negative powers of 2 are included, so inputs in (0, 1) map to fractional
// powers such as 0.5 or 0.25.
double GetNextPowerOfTwo(double n);
// Rounds n to the nearest multiple of base. Ties are broken towards +inf.
// If base is 0, returns n.
double RoundToNearestDoubleMultiple(double n, double base);
// Integer counterpart of RoundToNearestDoubleMultiple().
// NOTE(review): tie-breaking and the base == 0 case are not documented for this
// overload here — confirm against its definition.
int64_t RoundToNearestInt64Multiple(int64_t n, int64_t base);
// Rounds n to the nearest multiple of base, dispatching on the kind of T:
// integral types are rounded by RoundToNearestInt64Multiple() and
// floating-point types by RoundToNearestDoubleMultiple().
//
// These are constrained templates rather than plain overloads because plain
// overloads make calls such as RoundToNearestMultiple(5, 3) ambiguous.
template <typename T, std::enable_if_t<std::is_integral<T>::value>* = nullptr>
T RoundToNearestMultiple(T n, T base) {
  // The rounding happens in int64_t; convert the result back to T.
  const int64_t rounded = RoundToNearestInt64Multiple(n, base);
  return static_cast<T>(rounded);
}
template <typename T,
          std::enable_if_t<std::is_floating_point<T>::value>* = nullptr>
T RoundToNearestMultiple(T n, T base) {
  // The rounding happens in double; convert the result back to T.
  const double rounded = RoundToNearestDoubleMultiple(n, base);
  return static_cast<T>(rounded);
}
// Returns the sign of n as a value of type T: 1 for positive n, -1 for
// negative n, and 0 otherwise (including zero and values that compare neither
// greater nor less than zero).
template <typename T>
T sign(T n) {
  const int result = (n > 0) ? 1 : ((n < 0) ? -1 : 0);
  return result;
}
// Approximates the inverse of the error function, erf^-1(x).
// Implementation based on Table 5 in Giles' paper
// on approximating the inverse of the error function
// (https://people.maths.ox.ac.uk/gilesm/files/gems_erfinv.pdf).
// NOTE(review): the accepted domain of x (presumably (-1, 1)) and the behavior
// outside it are defined by the implementation — confirm there.
double InverseErrorFunction(double x);
// Estimation of the inverse cdf of the normal distribution centered at mu with
// standard deviation sigma, at probability p. Based on Abramowitz and Stegun
// formula 26.2.23. The error of the estimation is bounded by 4.5e-4, so callers
// that need higher accuracy should not rely on this function.
// Returns a non-OK status when the estimate cannot be produced.
// NOTE(review): which inputs are rejected (e.g. p outside (0, 1), sigma <= 0)
// is decided by the definition — confirm there.
absl::StatusOr<double> Qnorm(double p, double mu = 0.0, double sigma = 1.0);
// Returns `value` limited to the closed range [low, high]: `high` when value is
// above it, `low` when value is below it, and `value` otherwise. Note that the
// bounds come first, unlike std::clamp. The result refers to one of the
// arguments, so it must not outlive them.
template <typename T>
inline const T& Clamp(const T& low, const T& high, const T& value) {
  // Catch callers that pass the bounds in the wrong order.
  DCHECK(!(high < low));
  return (high < value) ? high : ((value < low) ? low : value);
}
// Return value for the Safe* operation functions below: the resulting value of
// the operation (possibly clamped or wrapped into the range of T) and whether
// or not the operation overflowed.
template <typename T>
struct SafeOpResult {
  // Value-initialized so that a default-constructed result never holds an
  // indeterminate value.
  T value{};
  bool overflow = false;
};

// When T is an integral type, returns the sum with `overflow == false` if it
// fits in T. Otherwise, returns the numeric limit of T in the direction of the
// overflow (max() or lowest()) with `overflow == true`.
// Note that this should NOT be used to gracefully handle overflows in
// computations on data. See (broken link)
template <typename T, std::enable_if_t<std::is_integral<T>::value>* = nullptr>
inline SafeOpResult<T> SafeAdd(T lhs, T rhs) {
  constexpr T kMax = std::numeric_limits<T>::max();
  constexpr T kLowest = std::numeric_limits<T>::lowest();
  if (lhs > 0 && rhs > 0) {
    // Only two positive operands can exceed max(). `kMax - lhs` is safe to
    // compute because lhs is positive.
    if (kMax - lhs < rhs) {
      return SafeOpResult<T>{kMax, true};
    }
  } else if (lhs < 0 && rhs < 0) {
    // Only two negative operands can fall below lowest(). `kLowest - lhs` is
    // safe to compute because lhs is negative.
    if (kLowest - lhs > rhs) {
      return SafeOpResult<T>{kLowest, true};
    }
  }
  // Operands narrower than int are promoted, so the sum is an int. The cast
  // (exact here) avoids an ill-formed narrowing in the braced initializer.
  return SafeOpResult<T>{static_cast<T>(lhs + rhs), false};
}

// When T is a floating-point type, perform a simple addition, since
// floating-point types don't have the same overflow issues as integral types.
// Note that this should NOT be used to gracefully handle overflows in
// computations on data. See (broken link)
template <typename T,
          std::enable_if_t<std::is_floating_point<T>::value>* = nullptr>
inline SafeOpResult<T> SafeAdd(T lhs, T rhs) {
  return SafeOpResult<T>{lhs + rhs, false};
}

// When T is an integral type, returns the difference with `overflow == false`
// if it fits in T. Otherwise, returns max() (difference too large) or lowest()
// (difference too small) with `overflow == true`. Works for signed and unsigned
// T alike; for unsigned T, any lhs < rhs underflows.
// Note that this should NOT be used to gracefully handle overflows in
// computations on data. See (broken link)
template <typename T, std::enable_if_t<std::is_integral<T>::value>* = nullptr>
inline SafeOpResult<T> SafeSubtract(T lhs, T rhs) {
  constexpr T kMax = std::numeric_limits<T>::max();
  constexpr T kLowest = std::numeric_limits<T>::lowest();
  // lhs - rhs > max() is only possible for negative rhs, and then `kMax + rhs`
  // cannot overflow. rhs is never negated, so rhs == lowest() (whose negation
  // is not representable in signed types) needs no special case.
  if (rhs < 0 && lhs > kMax + rhs) {
    return SafeOpResult<T>{kMax, true};
  }
  // lhs - rhs < lowest() is only possible for positive rhs, and then
  // `kLowest + rhs` cannot overflow. For unsigned T this is simply lhs < rhs.
  if (rhs > 0 && lhs < kLowest + rhs) {
    return SafeOpResult<T>{kLowest, true};
  }
  // Operands narrower than int are promoted; the cast is exact here.
  return SafeOpResult<T>{static_cast<T>(lhs - rhs), false};
}

// When T is a floating-point type, perform a simple subtraction, since
// floating-point types don't have the same overflow issues as integral types.
// Note that this should NOT be used to gracefully handle overflows in
// computations on data. See (broken link)
template <typename T,
          std::enable_if_t<std::is_floating_point<T>::value>* = nullptr>
inline SafeOpResult<T> SafeSubtract(T lhs, T rhs) {
  return SafeOpResult<T>{lhs - rhs, false};
}

// Squares an integral value. Returns the square with `overflow == false` when
// it fits in T; otherwise returns value 0 with `overflow == true`.
// Note that this should NOT be used to gracefully handle overflows in
// computations on data. See (broken link)
template <typename T, std::enable_if_t<std::is_integral<T>::value>* = nullptr>
inline SafeOpResult<T> SafeSquare(T num) {
  constexpr T kMax = std::numeric_limits<T>::max();
  // Floating-point estimate of floor(sqrt(kMax)). Converting kMax to double can
  // round it up (e.g. 2^64 - 1 becomes 2^64, whose root 2^32 is too large), so
  // correct the estimate downwards with exact integer arithmetic.
  T max_root = static_cast<T>(std::sqrt(static_cast<double>(kMax)));
  while (max_root > 0 && max_root > kMax / max_root) {
    max_root = static_cast<T>(max_root - 1);
  }
  const bool overflow = (num > 0 && num > max_root) ||
                        (num < 0 && num < -1 * max_root);
  if (overflow) {
    return SafeOpResult<T>{0, true};
  }
  // Operands narrower than int are promoted; the cast is exact here.
  return SafeOpResult<T>{static_cast<T>(num * num), false};
}

// Tries to convert a double value to an integral value. A static_cast from an
// out-of-range double is undefined behavior (it can crash with SIGILL), so
// out-of-range inputs are first wrapped into the range of T the same way
// integer overflow wraps around, e.g. for T = int16_t:
//   32767 -> 32767, 32768 -> -32768, 40000 -> -25536, -40000 -> 25536.
// The wrapped value is truncated toward zero. Returns that value together with
// whether the conversion overflowed. NaN and infinities cannot be represented
// and return numeric_limits<T>::quiet_NaN() (which is 0 for integral types)
// with `overflow == true`.
template <typename T, std::enable_if_t<std::is_integral<T>::value>* = nullptr>
inline SafeOpResult<T> SafeCastFromDouble(const double in) {
  if (!std::isfinite(in)) {
    // Integral types do not support NaN or infinite values.
    return SafeOpResult<T>{std::numeric_limits<T>::quiet_NaN(), true};
  }
  // Keep the limits as T (for the returned value) and as double (for the
  // arithmetic). Storing them in int64_t would turn uint64_t's max() into -1.
  constexpr T kTMax = std::numeric_limits<T>::max();
  constexpr T kTLowest = std::numeric_limits<T>::lowest();
  const double t_max = static_cast<double>(kTMax);
  const double t_lowest = static_cast<double>(kTLowest);
  // Number of distinct values of T, i.e. the period of the wrap-around.
  const double t_range_size = 1.0 + t_max - t_lowest;

  bool overflow = false;
  double d_out = in;
  if (d_out > t_max || d_out < t_lowest) {
    overflow = true;
    // Remove whole multiples of the range size so that d_out lands in
    // [lowest, lowest + range_size). For T = int16_t (range size 65536):
    //    40000 - 65536 * floor(( 40000 + 32768) / 65536) =  40000 - 65536
    //                                                     = -25536
    //   -40000 - 65536 * floor((-40000 + 32768) / 65536) = -40000 + 65536
    //                                                     =  25536
    // Flooring relative to lowest() works for signed and unsigned T alike.
    d_out -= t_range_size * std::floor((d_out - t_lowest) / t_range_size);
  }

  const double d_out_trunc = std::trunc(d_out);
  // Doubles cannot represent every integer near the limits of 64-bit types
  // (e.g. int64_t's max() becomes 2^63 as a double), so a value that looks
  // in range may still be out of range for a static_cast. Snap anything at or
  // beyond a limit to that limit instead of casting it.
  T out;
  if (d_out_trunc >= t_max) {
    out = kTMax;
  } else if (d_out_trunc <= t_lowest) {
    out = kTLowest;
  } else {
    out = static_cast<T>(d_out_trunc);
  }
  return SafeOpResult<T>{out, overflow};
}

// Converts double to other floating points. This should be mostly a no-op since
// we are typically only using doubles.
template <typename T,
          std::enable_if_t<std::is_floating_point<T>::value>* = nullptr>
inline SafeOpResult<T> SafeCastFromDouble(const double in) {
  return SafeOpResult<T>{static_cast<T>(in), false};
}
// Returns the arithmetic mean of the elements of `v`, summed in double
// precision from first to last. An empty vector yields 0.0.
template <typename T>
inline double Mean(const std::vector<T>& v) {
  if (v.empty()) return 0.0;
  double total = 0.0;
  for (const T& element : v) {
    total = total + element;
  }
  return total / v.size();
}
// Returns the population variance of `v`: the mean squared deviation from
// Mean(v), dividing by v.size() (not v.size() - 1). An empty vector yields 0.0.
template <typename T>
inline double Variance(const std::vector<T>& v) {
  if (v.empty()) {
    return 0.0;
  }
  const double mean = Mean(v);
  double sum_of_squares = 0.0;
  for (const T& num : v) {
    // Square by multiplication; std::pow(x, 2) is a general-purpose routine
    // and needlessly slow for this.
    const double deviation = num - mean;
    sum_of_squares += deviation * deviation;
  }
  return sum_of_squares / v.size();
}
// Returns the population standard deviation of `v`, i.e. the square root of
// Variance(v). An empty vector yields 0.0.
template <typename T>
inline double StandardDev(const std::vector<T>& v) {
  // std::sqrt is the idiomatic (and faster) form of std::pow(x, 0.5); the
  // variance is never negative, so both give the same result here.
  return std::sqrt(Variance(v));
}
// Returns the value at `percentile` (expected to be between 0 and 1) of the
// elements of `v`, linearly interpolating between the two nearest sorted
// elements. The element at sorted index i is treated as sitting at position
// (i + 0.5) / n. Percentiles at or below the first position return the minimum
// and those at or above the last return the maximum. An empty vector yields 0.
template <typename T>
inline T OrderStatistic(double percentile, const std::vector<T>& v) {
  if (v.empty()) return 0.0;
  std::vector<T> values(v);
  std::sort(values.begin(), values.end());
  const int n = static_cast<int>(values.size());
  const double pos = n * percentile - 0.5;
  // Written as !(pos > 0.0) so that a NaN percentile also lands here instead
  // of reaching the int conversion below.
  if (!(pos > 0.0)) return values[0];
  if (pos >= n - 1) return values[n - 1];
  const int index = static_cast<int>(pos);
  const double fraction = pos - index;
  // Interpolate in the *sorted* copy; the input `v` is in arbitrary order.
  return (1.0 - fraction) * values[index] + fraction * values[index + 1];
}
// Given two numeric vectors of equal length, returns their (Pearson) linear
// correlation coefficient. Returns NaN if the vectors have fewer than two
// elements, have unequal lengths, or if either has (almost) zero variance.
template <typename T>
double Correlation(const std::vector<T>& x, const std::vector<T>& y) {
  const std::size_t n = x.size();
  if (n < 2 || n != y.size()) {
    return std::numeric_limits<double>::quiet_NaN();
  }
  // First get the means. Sum in double rather than in T: with an integral T
  // the division below would otherwise be integer division (and the sums could
  // overflow), producing a wrong mean and hence a wrong coefficient.
  double sum_x = 0.0;
  double sum_y = 0.0;
  for (std::size_t i = 0; i < n; ++i) {
    sum_x += x[i];
    sum_y += y[i];
  }
  const double mean_x = sum_x / n;
  const double mean_y = sum_y / n;
  // Then the (unnormalized) variances and covariance.
  double sum_xx = 0.0;
  double sum_yy = 0.0;
  double sum_xy = 0.0;
  for (std::size_t i = 0; i < n; ++i) {
    const double delta_x = x[i] - mean_x;
    const double delta_y = y[i] - mean_y;
    sum_xx += delta_x * delta_x;
    sum_xy += delta_x * delta_y;
    sum_yy += delta_y * delta_y;
  }
  // Sums of squared deviations at or below this threshold count as zero
  // variance, for which the coefficient is undefined.
  constexpr double kMinSumOfSquares = 1e-10;
  if (sum_xx > kMinSumOfSquares && sum_yy > kMinSumOfSquares) {
    return sum_xy / std::sqrt(sum_xx * sum_yy);
  }
  return std::numeric_limits<double>::quiet_NaN();
}
// Filters `v` using `selection`, which is true at index i if v[i] is selected.
// Returns only the selected elements of `v`, preserving their order. The two
// vectors are expected to have the same length (checked by DCHECK); only
// indices present in both are considered, so a mismatch cannot read out of
// bounds if the DCHECK is compiled out.
template <typename T>
std::vector<T> VectorFilter(const std::vector<T>& v,
                            const std::vector<bool>& selection) {
  DCHECK(v.size() == selection.size());
  std::vector<T> result;
  // size_t index: an int index would mix signed/unsigned comparisons and could
  // overflow for very large vectors.
  const std::size_t n = std::min(v.size(), selection.size());
  for (std::size_t i = 0; i < n; ++i) {
    if (selection[i]) {
      result.push_back(v[i]);
    }
  }
  return result;
}
// Formats `v` as a bracketed, comma-separated list, e.g. "[1, 2, 3]".
template <typename T>
std::string VectorToString(const std::vector<T>& v) {
  const std::string joined = absl::StrJoin(v, ", ");
  return absl::StrCat("[", joined, "]");
}
// The functions below provide a common and consistent way for validating
// arguments and formatting error messages.
// Returns absl::OkStatus() if optional `opt` is set. Otherwise, returns an
// `error_code` error status that includes `name` in the error message.
absl::Status ValidateIsSet(
absl::optional<double> opt, absl::string_view name,
absl::StatusCode error_code = absl::StatusCode::kInvalidArgument);
// Returns absl::OkStatus() if optional `opt` is set and positive. Otherwise,
// returns an `error_code` error status that includes `name` in the error
// message.
absl::Status ValidateIsPositive(
absl::optional<double> opt, absl::string_view name,
absl::StatusCode error_code = absl::StatusCode::kInvalidArgument);
// Returns absl::OkStatus() if optional `opt` is set and non-negative.
// Otherwise, returns an `error_code` error status that includes `name` in the
// error message.
absl::Status ValidateIsNonNegative(
absl::optional<double> opt, absl::string_view name,
absl::StatusCode error_code = absl::StatusCode::kInvalidArgument);
// Returns absl::OkStatus() if optional `opt` is set and finite. Otherwise,
// returns an `error_code` error status that includes `name` in the error
// message.
absl::Status ValidateIsFinite(
absl::optional<double> opt, absl::string_view name,
absl::StatusCode error_code = absl::StatusCode::kInvalidArgument);
// Returns absl::OkStatus() if optional `opt` is set, finite, and positive.
// Otherwise, returns an `error_code` error status that includes `name` in the
// error message.
absl::Status ValidateIsFiniteAndPositive(
absl::optional<double> opt, absl::string_view name,
absl::StatusCode error_code = absl::StatusCode::kInvalidArgument);
// Returns absl::OkStatus() if optional `opt` is set, finite, and non-negative.
// Otherwise, returns an `error_code` error status that includes `name` in the
// error message.
absl::Status ValidateIsFiniteAndNonNegative(
absl::optional<double> opt, absl::string_view name,
absl::StatusCode error_code = absl::StatusCode::kInvalidArgument);
// Returns absl::OkStatus() if optional `opt` is set and within the inclusive
// (i.e., closed) interval [`lower_bound`, `upper_bound`]. Otherwise, returns an
// `error_code` error status that includes `name` in the error message.
absl::Status ValidateIsInInclusiveInterval(
absl::optional<double> opt, double lower_bound, double upper_bound,
absl::string_view name,
absl::StatusCode error_code = absl::StatusCode::kInvalidArgument);
// Returns absl::OkStatus() if optional `opt` is set and within the exclusive
// (i.e., open) interval (`lower_bound`, `upper_bound`). Otherwise, returns an
// `error_code` error status that includes `name` in the error message.
absl::Status ValidateIsInExclusiveInterval(
absl::optional<double> opt, double lower_bound, double upper_bound,
absl::string_view name,
absl::StatusCode error_code = absl::StatusCode::kInvalidArgument);
// Returns absl::OkStatus() if optional `opt` is set and strictly less than
// `upper_bound`. Otherwise, returns an `error_code` error status that includes
// `name` in the error message.
absl::Status ValidateIsLesserThan(
absl::optional<double> opt, double upper_bound, absl::string_view name,
absl::StatusCode error_code = absl::StatusCode::kInvalidArgument);
// Returns absl::OkStatus() if optional `opt` is set and less than or equal to
// `upper_bound`. Otherwise, returns an `error_code` error status that includes
// `name` in the error message.
absl::Status ValidateIsLesserThanOrEqualTo(
absl::optional<double> opt, double upper_bound, absl::string_view name,
absl::StatusCode error_code = absl::StatusCode::kInvalidArgument);
// Returns absl::OkStatus() if optional `opt` is set and strictly greater than
// `lower_bound`. Otherwise, returns an `error_code` error status that includes
// `name` in the error message.
absl::Status ValidateIsGreaterThan(
absl::optional<double> opt, double lower_bound, absl::string_view name,
absl::StatusCode error_code = absl::StatusCode::kInvalidArgument);
// Returns absl::OkStatus() if optional `opt` is set and greater than or equal
// to `lower_bound`. Otherwise, returns an `error_code` error status that
// includes `name` in the error message.
absl::Status ValidateIsGreaterThanOrEqualTo(
absl::optional<double> opt, double lower_bound, absl::string_view name,
absl::StatusCode error_code = absl::StatusCode::kInvalidArgument);
// Returns absl::OkStatus() if optional `opt` is set and within the interval
// between `lower_bound` and `upper_bound`, where each bound is itself allowed
// only if `include_lower` / `include_upper` (respectively) is true. Otherwise,
// returns an `error_code` error status that includes `name` in the error
// message.
absl::Status ValidateIsInInterval(
absl::optional<double> opt, double lower_bound, double upper_bound,
bool include_lower, bool include_upper, absl::string_view name,
absl::StatusCode error_code = absl::StatusCode::kInvalidArgument);
// Consistent validation of common algorithm parameters; each returns
// absl::OkStatus() for an acceptable value and an error status otherwise.
// NOTE(review): the accepted range of each parameter (and whether an unset
// value is allowed) is defined in the implementations, not here.
absl::Status ValidateEpsilon(absl::optional<double> epsilon);
absl::Status ValidateDelta(absl::optional<double> delta);
absl::Status ValidateMaxPartitionsContributed(
absl::optional<double> max_partitions_contributed);
absl::Status ValidateMaxContributionsPerPartition(
absl::optional<double> max_contributions_per_partition);
absl::Status ValidateMaxContributions(absl::optional<int> max_contributions);
// Validates common tree parameters.
absl::Status ValidateTreeHeight(absl::optional<int> tree_height);
absl::Status ValidateBranchingFactor(absl::optional<int> branching_factor);
// Validates a pair of optional bounds. Returns absl::OkStatus() if both bounds
// are unset, or if both are set, finite, and lower <= upper. Otherwise returns
// an InvalidArgument error describing the problem.
template <typename T>
absl::Status ValidateBounds(absl::optional<T> lower, absl::optional<T> upper) {
  const bool has_lower = lower.has_value();
  const bool has_upper = upper.has_value();
  if (has_lower != has_upper) {
    return absl::InvalidArgumentError(
        "Lower and upper bounds must either both be set or both be unset.");
  }
  if (!has_lower) {
    // Neither bound is set, which is allowed.
    return absl::OkStatus();
  }
  RETURN_IF_ERROR(ValidateIsFinite(*lower, "Lower bound"));
  RETURN_IF_ERROR(ValidateIsFinite(*upper, "Upper bound"));
  if (*lower > *upper) {
    return absl::InvalidArgumentError(
        "Lower bound cannot be greater than upper bound.");
  }
  return absl::OkStatus();
}
} // namespace differential_privacy
#endif // DIFFERENTIAL_PRIVACY_ALGORITHMS_UTIL_H_