Improve numerical stability of RDP computation.

PiperOrigin-RevId: 628482307
Change-Id: If51bed053a057b782c534ced37b13aea1b9bc96d
GitOrigin-RevId: b3467a5f002241d940484761cc0b4263885f84cd
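
The core fix: several places computed log(1 - exp(x)) directly, which loses all
precision once exp(x) falls below the float64 spacing around 1.0; they now use
log1p(-exp(x)). Binomial coefficients in the fractional-alpha series are also
taken in log space, and a slightly negative Renyi divergence (numerical noise
from extreme DpEvents) is downgraded from a ValueError to a warning. A minimal
standalone sketch of the log1p behaviour this relies on, not part of the patch
itself:

  import math
  import numpy as np

  # For very negative x, exp(x) is far below the float64 spacing at 1.0,
  # so 1 - exp(x) rounds to exactly 1 and the naive form returns 0.
  x = -50.0
  print(np.log(1 - np.exp(x)))    # 0.0, all precision lost
  print(math.log1p(-np.exp(x)))   # about -1.93e-22, correct magnitude
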
diff --git a/python/dp_accounting/dp_accounting/rdp/BUILD.bazel b/python/dp_accounting/dp_accounting/rdp/BUILD.bazel
index f7a07d2..c711f7b 100644
--- a/python/dp_accounting/dp_accounting/rdp/BUILD.bazel
+++ b/python/dp_accounting/dp_accounting/rdp/BUILD.bazel
@@ -36,6 +36,7 @@
     deps = [
         "//dp_accounting:dp_event",
         "//dp_accounting:privacy_accountant",
+        requirement("absl-py"),
         requirement("numpy"),
         requirement("scipy"),
     ],
diff --git a/python/dp_accounting/dp_accounting/rdp/rdp_privacy_accountant.py b/python/dp_accounting/dp_accounting/rdp/rdp_privacy_accountant.py
index 9d2b9a3..1a5f809 100644
--- a/python/dp_accounting/dp_accounting/rdp/rdp_privacy_accountant.py
+++ b/python/dp_accounting/dp_accounting/rdp/rdp_privacy_accountant.py
@@ -17,6 +17,7 @@
 import math
 from typing import Callable, Optional, Sequence, Tuple, Union
 
+from absl import logging
 import numpy as np
 from scipy import special
 
@@ -38,30 +39,14 @@
   return math.log1p(math.exp(a - b)) + b  # log1p(x) = log(x + 1)
 
 
-def _log_sub(logx: float, logy: float) -> float:
-  """Subtracts two numbers in the log space. Answer must be non-negative."""
-  if logx < logy:
-    raise ValueError('The result of subtraction must be non-negative.')
-  if logy == -np.inf:  # subtracting 0
-    return logx
-  if logx == logy:
-    return -np.inf  # 0 is represented as -np.inf in the log space.
-
-  try:
-    # Use exp(x) - exp(y) = (exp(x - y) - 1) * exp(y).
-    return math.log(math.expm1(logx - logy)) + logy  # expm1(x) = exp(x) - 1
-  except OverflowError:
-    return logx
-
-
 def _log_sub_sign(logx: float, logy: float) -> Tuple[bool, float]:
   """Returns log(exp(logx)-exp(logy)) and its sign."""
   if logx > logy:
     s = True
-    mag = logx + np.log(1 - np.exp(logy - logx))
+    mag = logx + math.log1p(-np.exp(logy - logx))
   elif logx < logy:
     s = False
-    mag = logy + np.log(1 - np.exp(logx - logy))
+    mag = logy + np.log1p(-np.exp(logx - logy))
   else:
     s = True
     mag = -np.inf
@@ -69,7 +54,7 @@
   return s, mag
 
 
-def _log_comb(n: int, k: int) -> float:
+def _log_comb(n: float, k: float) -> float:
   """Computes log of binomial coefficient."""
   return (special.gammaln(n + 1) - special.gammaln(k + 1) -
           special.gammaln(n - k + 1))
@@ -80,15 +65,14 @@
 
   # Initialize with 0 in the log space.
   log_a = -np.inf
+  log1mq = math.log1p(-q)
 
   for i in range(alpha + 1):
-    log_coef_i = (
-        _log_comb(alpha, i) + i * math.log(q) + (alpha - i) * math.log(1 - q))
-
+    log_coef_i = _log_comb(alpha, i) + i * math.log(q) + (alpha - i) * log1mq
     s = log_coef_i + (i * i - i) / (2 * (sigma**2))
     log_a = _log_add(log_a, s)
 
-  return float(log_a)
+  return log_a
 
 
 def _compute_log_a_frac(q: float, sigma: float, alpha: float) -> float:
@@ -96,17 +80,16 @@
   # The two parts of A_alpha, integrals over (-inf,z0] and [z0, +inf), are
   # initialized to 0 in the log space:
   log_a0, log_a1 = -np.inf, -np.inf
-  i = 0
-
   z0 = sigma**2 * math.log(1 / q - 1) + .5
+  log1mq = math.log1p(-q)
 
+  i = 0
   while True:  # do ... until loop
-    coef = special.binom(alpha, i)
-    log_coef = math.log(abs(coef))
+    log_coef = _log_comb(alpha, i)
     j = alpha - i
 
-    log_t0 = log_coef + i * math.log(q) + j * math.log(1 - q)
-    log_t1 = log_coef + j * math.log(q) + i * math.log(1 - q)
+    log_t0 = log_coef + i * math.log(q) + j * log1mq
+    log_t1 = log_coef + j * math.log(q) + i * log1mq
 
     log_e0 = math.log(.5) + _log_erfc((i - z0) / (math.sqrt(2) * sigma))
     log_e1 = math.log(.5) + _log_erfc((z0 - j) / (math.sqrt(2) * sigma))
@@ -114,12 +97,8 @@
     log_s0 = log_t0 + (i * i - i) / (2 * (sigma**2)) + log_e0
     log_s1 = log_t1 + (j * j - j) / (2 * (sigma**2)) + log_e1
 
-    if coef > 0:
-      log_a0 = _log_add(log_a0, log_s0)
-      log_a1 = _log_add(log_a1, log_s1)
-    else:
-      log_a0 = _log_sub(log_a0, log_s0)
-      log_a1 = _log_sub(log_a1, log_s1)
+    log_a0 = _log_add(log_a0, log_s0)
+    log_a1 = _log_add(log_a1, log_s1)
 
     i += 1
     if max(log_s0, log_s1) < -30:
@@ -177,11 +156,16 @@
     if a < 1:
       raise ValueError(f'Renyi divergence order must be at least 1. Found {a}.')
     if r < 0:
-      raise ValueError(f'Renyi divergence cannot be negative. Found {r}.')
+      logging.warning(
+          'Negative Renyi divergence of %s, probably caused by numerical '
+          'instability from extreme DpEvents. Returning a delta of zero.',
+          r,
+      )
+      logdelta = -np.inf
    # For small alpha, we are better off with the bound via KL divergence:
     # delta <= sqrt(1-exp(-KL)).
     # Take a min of the two bounds.
-    if r == 0:
+    elif r == 0:
       logdelta = -np.inf
     else:
       logdelta = 0.5 * math.log1p(-math.exp(-r))
@@ -237,9 +221,13 @@
     if a < 1:
       raise ValueError(f'Renyi divergence order must be at least 1. Found {a}.')
     if r < 0:
-      raise ValueError(f'Renyi divergence cannot be negative. Found {r}.')
-
-    if delta**2 + math.expm1(-r) > 0:
+      logging.warning(
+          'Negative Renyi divergence of %s, probably caused by numerical '
+          'instability from extreme DpEvents. Returning an epsilon of zero.',
+          r,
+      )
+      epsilon = 0.0
+    elif delta**2 + math.expm1(-r) > 0:
       # In this case, we can simply bound via KL divergence:
       # delta <= sqrt(1-exp(-KL)).
       epsilon = 0  # No need to try further computation if we have epsilon = 0.
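
With this change, a slightly negative RDP value (typically floating-point noise
from extreme DpEvents) no longer aborts accounting: compute_delta and
compute_epsilon log a warning and return zero for the offending order instead
of raising. A small usage sketch, assuming the module import path matches the
package layout (adjust for your install):

  from dp_accounting.rdp import rdp_privacy_accountant

  orders = [2.0]
  rdp = [-1e-12]  # tiny negative value produced by cancellation

  # Both calls previously raised ValueError; they now log a warning and
  # return zero for this order.
  eps, _ = rdp_privacy_accountant.compute_epsilon(orders, rdp, delta=1e-6)
  delta, _ = rdp_privacy_accountant.compute_delta(orders, rdp, epsilon=1.0)
  print(eps, delta)
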
@@ -468,7 +456,7 @@
   # Initialize with 1 in the log space.
   log_a = 0
   # Calculates the log term when alpha = 2
-  log_f2m1 = func(2.0) + np.log(1 - np.exp(-func(2.0)))
+  log_f2m1 = func(2.0) + math.log1p(-np.exp(-func(2.0)))
   if alpha <= max_alpha:
     # We need forward differences of exp(cgf)
     # The following line is the numerically stable way of implementing it.
@@ -777,13 +765,13 @@
 
 def _laplace_rdp(eps: float, order: float) -> float:
   """Computes RDP of Laplace noise addition.
-  
+
   See Proposition 6 of Mironov (2017): https://arxiv.org/abs/1702.07476
   RDP = eps + log[ 1 + (a-1) * [exp( (1-2*a) * eps) - 1] / (2*a-1) ] / (a-1).
   In contrast, the naive bound is RDP <= eps, which is tight as order->infinity.
   For small order & eps, we can do better: In general, RDP <= order * eps^2 / 2.
   The above formula is exactly tight for the Laplace mechanism.
-  
+
   Args:
     eps: The pure DP guarantee corresponding to the limit order->infinity.
     order: The Renyi divergence order (a.k.a. alpha).
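
For reference, the closed form quoted in the _laplace_rdp docstring above
(Proposition 6 of Mironov 2017) is easy to sanity-check standalone; the helper
below is an illustrative sketch, not the library function:

  import math

  def laplace_rdp_sketch(eps: float, order: float) -> float:
    # RDP = eps + log(1 + (a-1) * (exp((1-2a)*eps) - 1) / (2a-1)) / (a-1)
    a = order
    return eps + math.log1p(
        (a - 1) * math.expm1((1 - 2 * a) * eps) / (2 * a - 1)) / (a - 1)

  print(laplace_rdp_sketch(0.1, 1e6))   # ~0.1: approaches the pure-DP bound eps
  print(laplace_rdp_sketch(0.01, 2.0))  # ~1e-4: close to order * eps**2 / 2 for small eps
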
diff --git a/python/dp_accounting/dp_accounting/rdp/rdp_privacy_accountant_test.py b/python/dp_accounting/dp_accounting/rdp/rdp_privacy_accountant_test.py
index f22d3e9..bc6727d 100644
--- a/python/dp_accounting/dp_accounting/rdp/rdp_privacy_accountant_test.py
+++ b/python/dp_accounting/dp_accounting/rdp/rdp_privacy_accountant_test.py
@@ -263,30 +263,41 @@
     count = 50
     event = dp_event.SelfComposedDpEvent(
         dp_event.PoissonSampledDpEvent(
-            sampling_probability, dp_event.GaussianDpEvent(noise_multiplier)),
-        count)
+            sampling_probability, dp_event.GaussianDpEvent(noise_multiplier)
+        ),
+        count,
+    )
     accountant = rdp_privacy_accountant.RdpAccountant(orders=orders)
     accountant.compose(event)
     self.assertTrue(
         np.allclose(
-            accountant._rdp, [
-                6.5007e-04, 1.0854e-03, 2.1808e-03, 2.3846e-02, 1.6742e+02,
-                np.inf
+            accountant._rdp,
+            [
+                6.70579741e-04,
+                1.08548710e-03,
+                2.18075656e-03,
+                2.38460375e-02,
+                1.67416308e02,
+                np.inf,
             ],
-            rtol=1e-4))
+            rtol=1e-4,
+        )
+    )
 
   def test_compute_epsilon_delta_pure_dp(self):
     orders = range(2, 33)
     rdp = [1.1 for _ in orders]  # Constant corresponds to pure DP.
 
     epsilon, optimal_order = rdp_privacy_accountant.compute_epsilon(
-        orders, rdp, delta=1e-5)
+        orders, rdp, delta=1e-5
+    )
     # Compare with epsilon computed by hand.
     self.assertAlmostEqual(epsilon, 1.32783806176)
     self.assertEqual(optimal_order, 32)
 
     delta, optimal_order = rdp_privacy_accountant.compute_delta(
-        orders, rdp, epsilon=1.32783806176)
+        orders, rdp, epsilon=1.32783806176
+    )
     self.assertAlmostEqual(delta, 1e-5)
     self.assertEqual(optimal_order, 32)
 
@@ -352,7 +363,7 @@
     # computation.
     log_a = rdp_privacy_accountant._compute_log_a(q, sigma, order)
     log_a_mp = _log_float_mp(_compute_a_mp(sigma, q, order))
-    np.testing.assert_allclose(log_a, log_a_mp, rtol=1e-4)
+    np.testing.assert_allclose(log_a, log_a_mp, rtol=1e-4, atol=1e-8)
 
   def test_delta_bounds_gaussian(self):
     # Compare the optimal bound for Gaussian with the one derived from RDP.