// blob: d1b02329612a84b68d4b200d65a98b6e43f6378a [file] [log] [blame]
// Copyright 2024 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "src/algorithms/privacy/rmse.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "src/algorithms/privacy/numeric_encoding.h"
using ::testing::DoubleEq;
namespace cobalt {
// Checks PoissonUnaryRmse on a single hand-picked input: for lambda = 9 the
// test expects a result of 3 (which equals sqrt(9)), compared with DoubleEq.
TEST(RmseTest, Unary) {
  const double poisson_mean = 9.0;
  const double actual_rmse = PoissonUnaryRmse(poisson_mean);
  EXPECT_THAT(actual_rmse, DoubleEq(3.0));
}
namespace {
// Brute-force reference value used by the tests below as the expectation for
// PoissonSumRmse.
//
// For every index i in [0, num_index_points) the index is converted to a
// double x_i via cobalt::DoubleFromIndex(i, min_value, max_value,
// num_index_points), and lambda * x_i * x_i is added to a running total.
//
// Returns the square root of that total (0 when num_index_points is 0, since
// the loop body never runs).
double PoissonSumRmseSlow(double lambda, double min_value, double max_value,
                          uint64_t num_index_points) {
  double squared_error_total = 0;
  for (uint64_t index = 0; index != num_index_points; ++index) {
    const double point =
        cobalt::DoubleFromIndex(index, min_value, max_value, num_index_points);
    // Kept as one expression so the floating-point evaluation order is
    // (lambda * point) * point, then the accumulation.
    squared_error_total += lambda * point * point;
  }
  return std::sqrt(squared_error_total);
}
} // namespace
// Compares PoissonSumRmse against the brute-force PoissonSumRmseSlow reference
// for 11 index points over [0, 10].
// NOTE(review): per the test name, adjacent index points are presumably 1
// apart here; that depends on DoubleFromIndex, which is not visible in this
// file.
TEST(RmseTest, PoissonSumRmseSizeOneIntervals) {
  const double lambda = 17.82;
  const double min_value = 0;
  const double max_value = 10;
  const uint64_t num_index_points = 11;
  const double expected_rmse =
      PoissonSumRmseSlow(lambda, min_value, max_value, num_index_points);
  // The value under test goes first and the reference value goes inside the
  // matcher, so a failure message labels "actual" and "expected" correctly
  // (same argument order as the Unary test).
  EXPECT_THAT(PoissonSumRmse(lambda, min_value, max_value, num_index_points),
              DoubleEq(expected_rmse));
}
// Compares PoissonSumRmse against the brute-force PoissonSumRmseSlow reference
// for 11 index points over [0, 25], i.e. a wider range than the
// size-one-interval test with the same number of index points.
TEST(RmseTest, PoissonSumRmseSizeLargerIntervals) {
  const double lambda = 17.82;
  const double min_value = 0;
  const double max_value = 25;
  const uint64_t num_index_points = 11;
  const double expected_rmse =
      PoissonSumRmseSlow(lambda, min_value, max_value, num_index_points);
  // The value under test goes first and the reference value goes inside the
  // matcher, so a failure message labels "actual" and "expected" correctly
  // (same argument order as the Unary test).
  EXPECT_THAT(PoissonSumRmse(lambda, min_value, max_value, num_index_points),
              DoubleEq(expected_rmse));
}
// Compares PoissonSumRmse against the brute-force PoissonSumRmseSlow reference
// for 11 index points over [22, 51], i.e. a range whose lower bound is not
// zero.
TEST(RmseTest, PoissonSumRmseSizeNonZeroMin) {
  const double lambda = 17.82;
  const double min_value = 22;
  const double max_value = 51;
  const uint64_t num_index_points = 11;
  const double expected_rmse =
      PoissonSumRmseSlow(lambda, min_value, max_value, num_index_points);
  // The value under test goes first and the reference value goes inside the
  // matcher, so a failure message labels "actual" and "expected" correctly
  // (same argument order as the Unary test).
  EXPECT_THAT(PoissonSumRmse(lambda, min_value, max_value, num_index_points),
              DoubleEq(expected_rmse));
}
} // namespace cobalt