| // Copyright 2024 The Fuchsia Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
#include "src/algorithms/privacy/rmse.h"

#include <cmath>
#include <cstdint>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "src/algorithms/privacy/numeric_encoding.h"

using ::testing::DoubleEq;
| |
| namespace cobalt { |
| |
| TEST(RmseTest, Unary) { |
| double lambda = 9.0; |
| |
| EXPECT_THAT(PoissonUnaryRmse(lambda), DoubleEq(3.0)); |
| } |
| |
| namespace { |
| |
| double PoissonSumRmseSlow(double lambda, double min_value, double max_value, |
| uint64_t num_index_points) { |
| double expected_mse = 0; |
| |
| for (uint64_t i = 0; i < num_index_points; i++) { |
| double xi = cobalt::DoubleFromIndex(i, min_value, max_value, num_index_points); |
| expected_mse += lambda * xi * xi; |
| } |
| return std::sqrt(expected_mse); |
| } |
| |
| } // namespace |
| |
| TEST(RmseTest, PoissonSumRmseSizeOneIntervals) { |
| double lambda = 17.82; |
| double min_value = 0; |
| double max_value = 10; |
| uint64_t num_index_points = 11; |
| |
| double expected_rmse = PoissonSumRmseSlow(lambda, min_value, max_value, num_index_points); |
| |
| EXPECT_THAT(expected_rmse, |
| DoubleEq(PoissonSumRmse(lambda, min_value, max_value, num_index_points))); |
| } |
| |
| TEST(RmseTest, PoissonSumRmseSizeLargerIntervals) { |
| double lambda = 17.82; |
| double min_value = 0; |
| double max_value = 25; |
| uint64_t num_index_points = 11; |
| |
| double expected_rmse = PoissonSumRmseSlow(lambda, min_value, max_value, num_index_points); |
| |
| EXPECT_THAT(expected_rmse, |
| DoubleEq(PoissonSumRmse(lambda, min_value, max_value, num_index_points))); |
| } |
| |
| TEST(RmseTest, PoissonSumRmseSizeNonZeroMin) { |
| double lambda = 17.82; |
| double min_value = 22; |
| double max_value = 51; |
| uint64_t num_index_points = 11; |
| |
| double expected_rmse = PoissonSumRmseSlow(lambda, min_value, max_value, num_index_points); |
| |
| EXPECT_THAT(expected_rmse, |
| DoubleEq(PoissonSumRmse(lambda, min_value, max_value, num_index_points))); |
| } |
| |
| } // namespace cobalt |