#include "src/logger/privacy_encoder.h"

#include <gtest/gtest.h>

#include "src/algorithms/random/test_secure_random.h"
#include "src/lib/statusor/statusor.h"
#include "src/pb/observation.pb.h"
#include "src/registry/metric_definition.pb.h"
#include "src/registry/report_definition.pb.h"

namespace cobalt::logger {

// Fixture that owns a PrivacyEncoder built in SetUp() from a TestSecureRandomNumberGenerator and a
// RandomNumberGenerator, both seeded with 0. The protected wrappers below forward to PrivacyEncoder
// helpers so that test bodies can call them directly.
// NOTE(review): the wrappers presumably exist because those helpers are non-public and this
// fixture is granted access to them in privacy_encoder.h — confirm there.
class PrivacyEncoderTest : public testing::Test {
 protected:
  void SetUp() override {
    auto secure_rng = std::make_unique<TestSecureRandomNumberGenerator>(0);
    auto rng = std::make_unique<RandomNumberGenerator>(0);
    privacy_encoder_ = std::make_unique<PrivacyEncoder>(std::move(secure_rng), std::move(rng));
  }

  // Non-owning access to the encoder created in SetUp().
  [[nodiscard]] PrivacyEncoder *GetPrivacyEncoder() const { return privacy_encoder_.get(); }

  // Forwards to the static PrivacyEncoder::PrepareIndexVectorForUniqueDeviceCount().
  static lib::statusor::StatusOr<std::vector<uint64_t>> PrepareIndexVectorForUniqueDeviceCount(
      const Observation &observation, const MetricDefinition &metric_def) {
    return PrivacyEncoder::PrepareIndexVectorForUniqueDeviceCount(observation, metric_def);
  }

  // Forwards to PrepareIndexVectorForPerDeviceIntegerReport() on the fixture's encoder.
  lib::statusor::StatusOr<std::vector<uint64_t>> PrepareIndexVectorForPerDeviceIntegerReport(
      const Observation &observation, const MetricDefinition &metric_def,
      const ReportDefinition &report_def) {
    PrivacyEncoder *encoder = GetPrivacyEncoder();
    return encoder->PrepareIndexVectorForPerDeviceIntegerReport(observation, metric_def,
                                                                report_def);
  }

  // Forwards to PrepareIndexVectorForFleetwideMeansReport() on the fixture's encoder.
  lib::statusor::StatusOr<std::vector<uint64_t>> PrepareIndexVectorForFleetwideMeansReport(
      const Observation &observation, const MetricDefinition &metric_def,
      const ReportDefinition &report_def) {
    PrivacyEncoder *encoder = GetPrivacyEncoder();
    return encoder->PrepareIndexVectorForFleetwideMeansReport(observation, metric_def, report_def);
  }

  // Forwards to PrepareIndexVectorForPerDeviceHistogramsReport() on the fixture's encoder.
  lib::statusor::StatusOr<std::vector<uint64_t>> PrepareIndexVectorForPerDeviceHistogramsReport(
      const Observation &observation, const MetricDefinition &metric_def,
      const ReportDefinition &report_def) {
    PrivacyEncoder *encoder = GetPrivacyEncoder();
    return encoder->PrepareIndexVectorForPerDeviceHistogramsReport(observation, metric_def,
                                                                   report_def);
  }

  // Forwards to the static PrivacyEncoder::ObservationsFromIndices().
  static std::vector<std::unique_ptr<Observation>> ObservationsFromIndices(
      const std::vector<uint64_t> &indices) {
    return PrivacyEncoder::ObservationsFromIndices(indices);
  }

  // Forwards to AddNoise() on the fixture's encoder.
  lib::statusor::StatusOr<std::vector<uint64_t>> AddNoise(const std::vector<uint64_t> &indices,
                                                          const MetricDefinition &metric_def,
                                                          const ReportDefinition &report_def) {
    PrivacyEncoder *encoder = GetPrivacyEncoder();
    return encoder->AddNoise(indices, metric_def, report_def);
  }

  // Forwards to MakePrivateObservations() on the fixture's encoder. |observation| may be null.
  lib::statusor::StatusOr<std::vector<std::unique_ptr<Observation>>> MakePrivateObservations(
      const Observation *observation, const MetricDefinition &metric_def,
      const ReportDefinition &report_def) {
    PrivacyEncoder *encoder = GetPrivacyEncoder();
    return encoder->MakePrivateObservations(observation, metric_def, report_def);
  }

 private:
  std::unique_ptr<PrivacyEncoder> privacy_encoder_;
};

// With NO_ADDED_PRIVACY, MaybeMakePrivateObservations() hands the input observation back as-is:
// the result holds exactly one pointer, and it is the same object that was passed in.
TEST_F(PrivacyEncoderTest, MaybeMakePrivateObservationsNoAddedPrivacyReport) {
  MetricDefinition metric_def;
  ReportDefinition report_def;
  report_def.set_privacy_level(ReportDefinition::NO_ADDED_PRIVACY);
  auto observation = std::make_unique<Observation>();
  // Keep a non-owning pointer so identity can be checked after ownership is transferred.
  Observation *expected = observation.get();
  auto status_or = GetPrivacyEncoder()->MaybeMakePrivateObservations(std::move(observation),
                                                                     metric_def, report_def);
  ASSERT_TRUE(status_or.ok());
  // Take the value with ConsumeValueOrDie(), as the sibling NoAggregateData test does, instead of
  // std::move-ing out of the lvalue reference returned by ValueOrDie().
  std::vector<std::unique_ptr<Observation>> observations = status_or.ConsumeValueOrDie();
  ASSERT_EQ(observations.size(), 1u);
  EXPECT_EQ(observations[0].get(), expected);
}

// A null observation under NO_ADDED_PRIVACY is still forwarded: the result is a single null entry.
TEST_F(PrivacyEncoderTest, MaybeMakePrivateObservationsNoAddedPrivacyReportNoAggregateData) {
  ReportDefinition report_def;
  report_def.set_privacy_level(ReportDefinition::NO_ADDED_PRIVACY);
  MetricDefinition metric_def;

  auto result = GetPrivacyEncoder()->MaybeMakePrivateObservations(nullptr, metric_def, report_def);
  ASSERT_TRUE(result.ok());

  auto forwarded = result.ConsumeValueOrDie();
  ASSERT_EQ(forwarded.size(), 1u);
  EXPECT_EQ(forwarded[0].get(), nullptr);
}

// MakePrivateObservations() rejects a report that requests no added privacy with INVALID_ARGUMENT.
TEST_F(PrivacyEncoderTest, MakePrivateObservationsNoAddedPrivacyReport) {
  Observation observation;
  MetricDefinition metric_def;
  ReportDefinition report_def;
  report_def.set_privacy_level(ReportDefinition::NO_ADDED_PRIVACY);

  auto result = MakePrivateObservations(&observation, metric_def, report_def);
  ASSERT_FALSE(result.ok());
  EXPECT_EQ(result.status().error_code(), util::INVALID_ARGUMENT);
}

// For every report type listed here, MakePrivateObservations() must not fail with UNIMPLEMENTED.
// Only the absence of UNIMPLEMENTED is checked; success and other error codes are both acceptable
// outcomes, since the metric definition is left at its defaults.
TEST_F(PrivacyEncoderTest, MakePrivateObservationsImplemented) {
  const std::vector<ReportDefinition::ReportType> implemented_report_types = {
      ReportDefinition::UNIQUE_DEVICE_COUNTS,
      ReportDefinition::FLEETWIDE_OCCURRENCE_COUNTS,
      ReportDefinition::HOURLY_VALUE_NUMERIC_STATS,
      ReportDefinition::UNIQUE_DEVICE_NUMERIC_STATS,
      ReportDefinition::FLEETWIDE_MEANS,
      ReportDefinition::HOURLY_VALUE_HISTOGRAMS,
      ReportDefinition::UNIQUE_DEVICE_HISTOGRAMS};

  MetricDefinition metric_def;
  ReportDefinition report_def;
  report_def.set_privacy_level(ReportDefinition::LOW_PRIVACY);
  Observation observation;

  for (const ReportDefinition::ReportType report_type : implemented_report_types) {
    report_def.set_report_type(report_type);
    // Name the report type so a failure identifies which loop iteration misbehaved.
    EXPECT_NE(MakePrivateObservations(&observation, metric_def, report_def).status().error_code(),
              util::UNIMPLEMENTED)
        << "report type: " << ReportDefinition::ReportType_Name(report_type);
  }
}

// MakePrivateObservations() is not implemented yet for these report types. Move report types to the
// MakePrivateObservationsImplemented test as they are implemented.
TEST_F(PrivacyEncoderTest, MakePrivateObservationsUnimplemented) {
  const std::vector<ReportDefinition::ReportType> unimplemented_report_types = {
      ReportDefinition::FLEETWIDE_HISTOGRAMS, ReportDefinition::STRING_COUNTS};

  MetricDefinition metric_def;
  ReportDefinition report_def;
  report_def.set_privacy_level(ReportDefinition::LOW_PRIVACY);

  Observation observation;
  for (const ReportDefinition::ReportType report_type : unimplemented_report_types) {
    report_def.set_report_type(report_type);
    // Name the report type so a failure identifies which loop iteration misbehaved.
    EXPECT_EQ(MakePrivateObservations(&observation, metric_def, report_def).status().error_code(),
              util::UNIMPLEMENTED)
        << "report type: " << ReportDefinition::ReportType_Name(report_type);
  }
}

// A null observation for a LOW_PRIVACY UNIQUE_DEVICE_COUNTS report is accepted, not an error.
TEST_F(PrivacyEncoderTest, MakePrivateObservationsNullObservation) {
  ReportDefinition report_def;
  report_def.set_report_type(ReportDefinition::UNIQUE_DEVICE_COUNTS);
  report_def.set_privacy_level(ReportDefinition::LOW_PRIVACY);
  MetricDefinition metric_def;

  ASSERT_TRUE(MakePrivateObservations(nullptr, metric_def, report_def).ok());
}

// For a single-dimension OCCURRENCE metric, each event code present in the IntegerObservation is
// expected to map to an index equal to that event code.
TEST_F(PrivacyEncoderTest, UniqueDeviceCount) {
  MetricDefinition metric_def;
  metric_def.set_metric_type(MetricDefinition::OCCURRENCE);
  auto dimension = metric_def.add_metric_dimensions();
  dimension->set_dimension("dimension 0");
  dimension->set_max_event_code(10);

  // These values are both the logged event codes and the expected output indices.
  std::vector<uint64_t> event_codes = {2, 4, 6};
  Observation observation;
  auto integer_observation = observation.mutable_integer();
  for (uint64_t code : event_codes) {
    auto entry = integer_observation->add_values();
    entry->add_event_codes(code);
    entry->set_value(1);
  }

  auto indices = PrepareIndexVectorForUniqueDeviceCount(observation, metric_def);
  ASSERT_TRUE(indices.ok());
  EXPECT_EQ(indices.ValueOrDie(), event_codes);
}

// An Observation with no observation type set is rejected with INVALID_ARGUMENT.
TEST_F(PrivacyEncoderTest, UniqueDeviceCountInvalidObservationType) {
  MetricDefinition metric_def;
  metric_def.set_metric_type(MetricDefinition::OCCURRENCE);
  auto dimension = metric_def.add_metric_dimensions();
  dimension->set_dimension("dimension 0");
  dimension->set_max_event_code(10);

  // Deliberately left empty: no IntegerObservation is attached.
  const Observation untyped_observation;

  auto indices = PrepareIndexVectorForUniqueDeviceCount(untyped_observation, metric_def);
  ASSERT_FALSE(indices.ok());
  EXPECT_EQ(indices.status().error_code(), util::INVALID_ARGUMENT);
}

// Per-device integer values are mapped to an index point and combined with their event code.
TEST_F(PrivacyEncoderTest, PerDeviceInteger) {
  // 11 valid event vectors (event codes 0..10).
  MetricDefinition metric_def;
  metric_def.set_metric_type(MetricDefinition::OCCURRENCE);
  auto dimension = metric_def.add_metric_dimensions();
  dimension->set_dimension("dimension 0");
  dimension->set_max_event_code(10);

  // NOTE(review): the 6 index points presumably span [-4, 6] evenly: -4, -2, 0, 2, 4, 6.
  ReportDefinition report_def;
  report_def.set_min_value(-4);
  report_def.set_max_value(6);
  report_def.set_num_index_points(6);

  Observation observation;
  auto integer_observation = observation.mutable_integer();
  auto add_value = [integer_observation](uint32_t event_code, int64_t value) {
    auto entry = integer_observation->add_values();
    entry->add_event_codes(event_code);
    entry->set_value(value);
  };
  // Value -2 is index point 1, consistent with 11 * 1 + 3 = 14.
  add_value(3u, -2);
  // Value 2 is index point 3, consistent with 11 * 3 + 7 = 40.
  add_value(7u, 2);

  const std::vector<uint64_t> expected_indices = {14, 40};
  auto indices =
      PrepareIndexVectorForPerDeviceIntegerReport(observation, metric_def, report_def);
  ASSERT_TRUE(indices.ok());
  EXPECT_EQ(indices.ValueOrDie(), expected_indices);
}

// An Observation with no type set, alongside default definitions, is rejected with
// INVALID_ARGUMENT.
TEST_F(PrivacyEncoderTest, PerDeviceIntegerInvalidObservationType) {
  const Observation untyped_observation;
  const MetricDefinition metric_def;
  const ReportDefinition report_def;

  auto indices =
      PrepareIndexVectorForPerDeviceIntegerReport(untyped_observation, metric_def, report_def);
  ASSERT_FALSE(indices.ok());
  EXPECT_EQ(indices.status().error_code(), util::INVALID_ARGUMENT);
}

// A SumAndCount entry yields one index for its sum and one for its count.
TEST_F(PrivacyEncoderTest, FleetwideMeans) {
  // 11 valid event vectors (event codes 0..10).
  MetricDefinition metric_def;
  metric_def.set_metric_type(MetricDefinition::OCCURRENCE);
  auto dimension = metric_def.add_metric_dimensions();
  dimension->set_dimension("dimension 0");
  dimension->set_max_event_code(10);

  ReportDefinition report_def;
  report_def.set_min_value(-4);
  report_def.set_max_value(6);
  report_def.set_num_index_points(6);
  report_def.set_max_count(10);

  Observation observation;
  auto entry = observation.mutable_sum_and_count()->add_sums_and_counts();
  entry->add_event_codes(3u);
  entry->set_sum(-2);
  entry->set_count(4);

  // NOTE(review): derivation consistent with the expected values, assuming evenly spaced points.
  // Sum -2 is index point 1 of [-4, 6]: 11 * 1 + 3 = 14.
  // Count 4 is index point 2 of [0, 10]; count indices presumably follow the 6 * 11 = 66 sum
  // indices: 66 + 11 * 2 + 3 = 91.
  const std::vector<uint64_t> expected_indices = {14, 91};
  auto indices = PrepareIndexVectorForFleetwideMeansReport(observation, metric_def, report_def);
  ASSERT_TRUE(indices.ok());
  EXPECT_EQ(indices.ValueOrDie(), expected_indices);
}

// An Observation with no type set, alongside default definitions, is rejected with
// INVALID_ARGUMENT.
TEST_F(PrivacyEncoderTest, FleetwideMeansInvalidObservationType) {
  const Observation untyped_observation;
  const MetricDefinition metric_def;
  const ReportDefinition report_def;

  auto indices =
      PrepareIndexVectorForFleetwideMeansReport(untyped_observation, metric_def, report_def);
  ASSERT_FALSE(indices.ok());
  EXPECT_EQ(indices.status().error_code(), util::INVALID_ARGUMENT);
}

// Each integer value is placed in a histogram bucket, and the bucket index is combined with the
// event code.
TEST_F(PrivacyEncoderTest, PerDeviceHistograms) {
  // |metric_def| has 11 valid event vectors.
  MetricDefinition metric_def;
  metric_def.set_metric_type(MetricDefinition::OCCURRENCE);
  auto dimension = metric_def.add_metric_dimensions();
  dimension->set_dimension("dimension 0");
  dimension->set_max_event_code(10);

  // Counting the underflow and overflow buckets, these buckets have 8 + 2 = 10 valid bucket indices.
  ReportDefinition report_def;
  auto linear = report_def.mutable_int_buckets()->mutable_linear();
  linear->set_floor(0);
  linear->set_num_buckets(8);
  linear->set_step_size(2);

  Observation observation;
  auto integer_observation = observation.mutable_integer();
  auto add_value = [integer_observation](uint32_t event_code, int64_t value) {
    auto entry = integer_observation->add_values();
    entry->add_event_codes(event_code);
    entry->set_value(value);
  };
  // Event code 3, bucket index 1:
  // num_event_vectors * bucket_index + event_code = 11 * 1 + 3 = 14.
  add_value(3u, 0);
  // Event code 7, bucket index 3:
  // num_event_vectors * bucket_index + event_code = 11 * 3 + 7 = 40.
  add_value(7u, 5);

  const std::vector<uint64_t> expected_indices = {14, 40};
  auto indices =
      PrepareIndexVectorForPerDeviceHistogramsReport(observation, metric_def, report_def);
  ASSERT_TRUE(indices.ok());
  EXPECT_EQ(indices.ValueOrDie(), expected_indices);
}

// An Observation with no type set, alongside default definitions, is rejected with
// INVALID_ARGUMENT.
TEST_F(PrivacyEncoderTest, PerDeviceHistogramsInvalidObservationType) {
  const Observation untyped_observation;
  const MetricDefinition metric_def;
  const ReportDefinition report_def;

  auto indices =
      PrepareIndexVectorForPerDeviceHistogramsReport(untyped_observation, metric_def, report_def);
  ASSERT_FALSE(indices.ok());
  EXPECT_EQ(indices.status().error_code(), util::INVALID_ARGUMENT);
}

// With no indices, the only observation produced is a ReportParticipationObservation.
TEST_F(PrivacyEncoderTest, ObservationsFromIndicesNoIndices) {
  const std::vector<uint64_t> no_indices;
  auto produced = ObservationsFromIndices(no_indices);
  ASSERT_EQ(produced.size(), 1u);
  EXPECT_EQ(produced.front()->observation_type_case(), Observation::kReportParticipation);
}

// Each index becomes one PrivateIndexObservation, in input order, followed by a single trailing
// ReportParticipationObservation.
TEST_F(PrivacyEncoderTest, ObservationsFromIndices) {
  const std::vector<uint64_t> indices = {1, 2, 20, 50};
  auto produced = ObservationsFromIndices(indices);
  ASSERT_EQ(produced.size(), indices.size() + 1);

  // After the size check, the first indices.size() entries are the PrivateIndexObservations.
  for (size_t pos = 0; pos < indices.size(); ++pos) {
    ASSERT_EQ(produced[pos]->observation_type_case(), Observation::kPrivateIndex);
    EXPECT_EQ(produced[pos]->private_index().index(), indices[pos]);
  }
  EXPECT_EQ(produced.back()->observation_type_case(), Observation::kReportParticipation);
}

// For UNIQUE_DEVICE_COUNTS with one dimension, the largest index equals the max event code.
TEST_F(PrivacyEncoderTest, MaxIndexForReportUniqueDeviceCount) {
  const uint32_t max_event_code = 10;
  MetricDefinition metric_def;
  auto dimension = metric_def.add_metric_dimensions();
  dimension->set_dimension("dimension 1");
  dimension->set_max_event_code(max_event_code);

  ReportDefinition report_def;
  report_def.set_report_type(ReportDefinition::UNIQUE_DEVICE_COUNTS);

  auto max_index = PrivacyEncoder::MaxIndexForReport(metric_def, report_def);
  ASSERT_TRUE(max_index.ok());
  EXPECT_EQ(max_index.ValueOrDie(), uint64_t{max_event_code});
}

// MaxIndexForReport() for the per-device integer report types: with 11 event codes and 5 index
// points, the largest index is 54 for every type listed.
TEST_F(PrivacyEncoderTest, MaxIndexForReportPerDeviceInteger) {
  MetricDefinition metric_def;
  uint32_t max_event_code = 10;
  auto dim = metric_def.add_metric_dimensions();
  dim->set_dimension("dimension 1");
  dim->set_max_event_code(max_event_code);

  std::vector<ReportDefinition::ReportType> per_device_integer_report_types = {
      ReportDefinition::FLEETWIDE_OCCURRENCE_COUNTS, ReportDefinition::HOURLY_VALUE_NUMERIC_STATS,
      ReportDefinition::UNIQUE_DEVICE_NUMERIC_STATS};

  ReportDefinition report_def;
  uint32_t num_index_points = 5;
  report_def.set_num_index_points(num_index_points);
  // There are 11 event codes and 5 index points for a total of 55 possible indices.
  uint64_t expected_max_index = 54;

  for (auto report_type : per_device_integer_report_types) {
    // Attach the report type to every failure raised in this iteration.
    SCOPED_TRACE(testing::Message()
                 << "report type: " << ReportDefinition::ReportType_Name(report_type));
    report_def.set_report_type(report_type);

    auto status_or = PrivacyEncoder::MaxIndexForReport(metric_def, report_def);
    // Print the status on failure, matching MaxIndexForReportPerDeviceHistogram.
    ASSERT_TRUE(status_or.ok()) << "Failed to get max index with status "
                                << status_or.status().error_code() << ", "
                                << status_or.status().error_message();

    EXPECT_EQ(expected_max_index, status_or.ValueOrDie());
  }
}

// FLEETWIDE_MEANS doubles the per-device index space. Presumably there is one range for sums and one
// for counts, each of size num_event_vectors * num_index_points.
TEST_F(PrivacyEncoderTest, MaxIndexForReportFleetwideMeans) {
  const uint32_t max_event_code = 10;
  const uint32_t num_index_points = 5;

  MetricDefinition metric_def;
  auto dimension = metric_def.add_metric_dimensions();
  dimension->set_dimension("dimension 1");
  dimension->set_max_event_code(max_event_code);

  ReportDefinition report_def;
  report_def.set_report_type(ReportDefinition::FLEETWIDE_MEANS);
  report_def.set_num_index_points(num_index_points);

  auto max_index = PrivacyEncoder::MaxIndexForReport(metric_def, report_def);
  ASSERT_TRUE(max_index.ok());

  const uint64_t expected_max_index = (max_event_code + 1) * num_index_points * 2 - 1;
  EXPECT_EQ(max_index.ValueOrDie(), expected_max_index);
}

// MaxIndexForReport() for the per-device histogram report types: with 11 event codes and 5 buckets,
// the largest index is 54 for both linear and exponential bucket configurations.
TEST_F(PrivacyEncoderTest, MaxIndexForReportPerDeviceHistogram) {
  MetricDefinition metric_def;
  uint32_t max_event_code = 10;
  auto dim = metric_def.add_metric_dimensions();
  dim->set_dimension("dimension 1");
  dim->set_max_event_code(max_event_code);

  std::vector<ReportDefinition::ReportType> per_device_histogram_report_types = {
      ReportDefinition::HOURLY_VALUE_HISTOGRAMS, ReportDefinition::UNIQUE_DEVICE_HISTOGRAMS};

  IntegerBuckets linear_buckets;
  linear_buckets.mutable_linear()->set_floor(0);
  linear_buckets.mutable_linear()->set_num_buckets(3);

  IntegerBuckets exp_buckets;
  exp_buckets.mutable_exponential()->set_floor(0);
  exp_buckets.mutable_exponential()->set_num_buckets(3);

  // Variant index 0 is linear, index 1 is exponential (used in the failure trace below).
  std::vector<IntegerBuckets> bucket_variants = {linear_buckets, exp_buckets};

  ReportDefinition report_def;

  // There are 11 event codes and 5 histogram buckets (3 registered + 1 underflow + 1 overflow) for
  // a total of 55 possible indices.
  uint64_t expected_max_index = 54;

  for (const auto report_type : per_device_histogram_report_types) {
    // The report type is the same for every bucket variant, so set it once per outer iteration.
    report_def.set_report_type(report_type);
    for (size_t variant = 0; variant < bucket_variants.size(); ++variant) {
      // The assertions run inside nested loops; record which combination is being checked.
      SCOPED_TRACE(testing::Message()
                   << "report type: " << ReportDefinition::ReportType_Name(report_type)
                   << ", bucket variant index: " << variant);
      *report_def.mutable_int_buckets() = bucket_variants[variant];

      auto max_index = PrivacyEncoder::MaxIndexForReport(metric_def, report_def);
      ASSERT_TRUE(max_index.ok()) << "Failed to get max index with status "
                                  << max_index.status().error_code() << ", "
                                  << max_index.status().error_message();

      EXPECT_EQ(expected_max_index, max_index.ValueOrDie());
    }
  }
}

// MaxIndexForReport() returns UNIMPLEMENTED for CUSTOM_RAW_DUMP reports.
TEST_F(PrivacyEncoderTest, MaxIndexForReportUnimplemented) {
  ReportDefinition report_def;
  report_def.set_report_type(ReportDefinition::CUSTOM_RAW_DUMP);
  MetricDefinition metric_def;

  auto max_index = PrivacyEncoder::MaxIndexForReport(metric_def, report_def);
  ASSERT_FALSE(max_index.ok());
  EXPECT_EQ(max_index.status().error_code(), util::UNIMPLEMENTED);
}

// AddNoise() rejects prob_bit_flip values outside [0, 1] with INVALID_ARGUMENT.
TEST_F(PrivacyEncoderTest, AddNoisePOutOfRange) {
  MetricDefinition metric_def;
  ReportDefinition report_def;
  report_def.set_report_type(ReportDefinition::UNIQUE_DEVICE_COUNTS);
  const std::vector<uint64_t> no_indices;

  // A negative probability must be rejected.
  report_def.set_prob_bit_flip(-1.0);
  auto negative_p_result = AddNoise(no_indices, metric_def, report_def);
  ASSERT_FALSE(negative_p_result.ok());
  EXPECT_EQ(negative_p_result.status().error_code(), util::INVALID_ARGUMENT);

  // A probability greater than 1 must be rejected as well.
  report_def.set_prob_bit_flip(2.0);
  auto large_p_result = AddNoise(no_indices, metric_def, report_def);
  ASSERT_FALSE(large_p_result.ok());
  EXPECT_EQ(large_p_result.status().error_code(), util::INVALID_ARGUMENT);
}

// An index one past the largest valid index for the report (max_event_code + 1 here) is rejected
// with INVALID_ARGUMENT.
TEST_F(PrivacyEncoderTest, AddNoiseIndexOutOfRange) {
  const uint32_t max_event_code = 10;
  MetricDefinition metric_def;
  auto dimension = metric_def.add_metric_dimensions();
  dimension->set_dimension("dimension 1");
  dimension->set_max_event_code(max_event_code);

  ReportDefinition report_def;
  report_def.set_report_type(ReportDefinition::UNIQUE_DEVICE_COUNTS);
  report_def.set_prob_bit_flip(1.0);

  const std::vector<uint64_t> out_of_range = {max_event_code + 1};
  auto result = AddNoise(out_of_range, metric_def, report_def);
  ASSERT_FALSE(result.ok());
  EXPECT_EQ(result.status().error_code(), util::INVALID_ARGUMENT);
}

// AddNoise() succeeds for in-range indices with prob_bit_flip = 1.0.
TEST_F(PrivacyEncoderTest, AddNoise) {
  const uint32_t max_event_code = 10;
  MetricDefinition metric_def;
  auto dimension = metric_def.add_metric_dimensions();
  dimension->set_dimension("dimension 1");
  dimension->set_max_event_code(max_event_code);

  ReportDefinition report_def;
  report_def.set_report_type(ReportDefinition::UNIQUE_DEVICE_COUNTS);
  report_def.set_prob_bit_flip(1.0);

  const std::vector<uint64_t> in_range_indices = {1, 2, 3};
  EXPECT_TRUE(AddNoise(in_range_indices, metric_def, report_def).ok());
}

}  // namespace cobalt::logger
