// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <gtest/gtest.h>

#include "src/lib/util/datetime_util.h"
#include "src/local_aggregation_1.1/aggregation_procedures/aggregation_procedure.h"
#include "src/local_aggregation_1.1/aggregation_procedures/testing/test_aggregation_procedure.h"
#include "src/local_aggregation_1.1/testing/test_registry.cb.h"
#include "src/logger/project_context_factory.h"
#include "src/pb/metadata_builder.h"

namespace cobalt::local_aggregation {

class IntegerHistogramAggregationProcedureTest : public testing::TestAggregationProcedure {
 public:
  // Logs one IntegerEvent per event code in [0, num_event_codes) at `hour_id`, passing each one to
  // `procedure->UpdateAggregate` so that it is accumulated into `aggregate`. The event logged for
  // event code `i` carries the integer value `i`.
  void LogIntegerEvents(uint32_t hour_id, uint32_t num_event_codes, AggregationProcedure* procedure,
                        ReportAggregate* aggregate) {
    logger::EventRecord record = MakeEventRecord(util::TimeInfo::FromHourId(hour_id));
    IntegerEvent* event = record.event()->mutable_integer_event();
    // A single event-code slot is added once and then overwritten for every logged event.
    event->add_event_code(0);
    // The counter is unsigned to match `num_event_codes` (no signed/unsigned comparison, no
    // signed overflow for large counts).
    for (uint32_t i = 0; i < num_event_codes; i++) {
      event->set_event_code(0, i);
      event->set_value(i);
      procedure->UpdateAggregate(record, aggregate);
    }
  }

  // Logs one IntegerHistogramEvent per event code in [0, num_event_codes) at `hour_id`, passing
  // each one to `procedure->UpdateAggregate` so that it is accumulated into `aggregate`. Every
  // logged event carries the same `histogram`, given as a map from bucket index to bucket count.
  void LogIntegerHistogramEvents(uint32_t hour_id, uint32_t num_event_codes,
                                 const std::map<uint32_t, uint64_t>& histogram,
                                 AggregationProcedure* procedure, ReportAggregate* aggregate) {
    logger::EventRecord record = MakeEventRecord(util::TimeInfo::FromHourId(hour_id));
    IntegerHistogramEvent* event = record.event()->mutable_integer_histogram_event();
    // The buckets are identical for every event code, so they are filled in only once.
    for (const auto& [index, count] : histogram) {
      HistogramBucket* bucket = event->add_buckets();
      bucket->set_index(index);
      bucket->set_count(count);
    }
    // A single event-code slot is added once and then overwritten for every logged event.
    event->add_event_code(0);
    for (uint32_t i = 0; i < num_event_codes; i++) {
      event->set_event_code(0, i);
      procedure->UpdateAggregate(record, aggregate);
    }
  }
};

TEST_F(IntegerHistogramAggregationProcedureTest, UpdateAggregateWorksInteger) {
  uint32_t metric_id = kIntegerMetricMetricId;
  auto procedure = GetProcedureFor(metric_id, kIntegerMetricFleetwideHistogramsReportIndex);

  ReportAggregate aggregate;
  const uint32_t kNumEventCodes = 100;
  // The metric must be able to buffer every distinct event code logged below.
  ASSERT_GE(GetMetricDef(metric_id).event_code_buffer_max(), kNumEventCodes);

  const uint32_t kHourId = 1;
  LogIntegerEvents(kHourId, kNumEventCodes, procedure.get(), &aggregate);

  // Every event was logged in the same hour, so there must be exactly one hourly entry and it must
  // be keyed by kHourId. Check for the key explicitly before calling at(): protobuf's Map::at()
  // CHECK-fails on a missing key, which would abort the test binary rather than fail this test.
  ASSERT_EQ(aggregate.hourly().by_hour_id_size(), 1);
  ASSERT_EQ(aggregate.hourly().by_hour_id().count(kHourId), 1u);
  const auto& hour_aggregate = aggregate.hourly().by_hour_id().at(kHourId);

  // One aggregate (and one event vector) per distinct event code.
  ASSERT_EQ(hour_aggregate.by_event_code_size(), kNumEventCodes);
  ASSERT_EQ(hour_aggregate.event_vectors_size(), kNumEventCodes);
}

TEST_F(IntegerHistogramAggregationProcedureTest, UpdateAggregateWorksIntegerHistogram) {
  uint32_t metric_id = kIntegerHistogramMetricMetricId;
  auto procedure =
      GetProcedureFor(metric_id, kIntegerHistogramMetricFleetwideHistogramsReportIndex);

  ReportAggregate aggregate;
  const uint32_t kNumEventCodes = 100;
  // The metric must be able to buffer every distinct event code logged below.
  ASSERT_GE(GetMetricDef(metric_id).event_code_buffer_max(), kNumEventCodes);

  const uint32_t kHourId = 1;
  LogIntegerHistogramEvents(kHourId, kNumEventCodes, {{1, 10}, {2, 100}, {3, 50}}, procedure.get(),
                            &aggregate);

  // Every event was logged in the same hour, so there must be exactly one hourly entry and it must
  // be keyed by kHourId. Check for the key explicitly before calling at(): protobuf's Map::at()
  // CHECK-fails on a missing key, which would abort the test binary rather than fail this test.
  ASSERT_EQ(aggregate.hourly().by_hour_id_size(), 1);
  ASSERT_EQ(aggregate.hourly().by_hour_id().count(kHourId), 1u);
  const auto& hour_aggregate = aggregate.hourly().by_hour_id().at(kHourId);

  // One aggregate (and one event vector) per distinct event code.
  ASSERT_EQ(hour_aggregate.by_event_code_size(), kNumEventCodes);
  ASSERT_EQ(hour_aggregate.event_vectors_size(), kNumEventCodes);
}

TEST_F(IntegerHistogramAggregationProcedureTest, GenerateObservationWorksInteger) {
  uint32_t metric_id = kIntegerMetricMetricId;
  auto procedure = GetProcedureFor(metric_id, kIntegerMetricFleetwideHistogramsReportIndex);

  ReportAggregate aggregate;
  const uint32_t kNumEventCodes = 10;
  ASSERT_GE(GetMetricDef(metric_id).event_code_buffer_max(), kNumEventCodes);

  // Log the same events in every other hour: 1, 3, 5, 7, 9 and 11. The counter is unsigned so it
  // compares cleanly against kEndHourId and matches LogIntegerEvents' hour_id parameter.
  const uint32_t kEndHourId = 11;
  for (uint32_t hour_id = 1; hour_id <= kEndHourId; hour_id += 2) {
    LogIntegerEvents(hour_id, kNumEventCodes, procedure.get(), &aggregate);
  }

  auto observation_or =
      GenerateObservation(util::TimeInfo::FromHourId(kEndHourId), procedure.get(), &aggregate);
  ASSERT_TRUE(observation_or.ok());

  auto observation = observation_or.ConsumeValueOrDie();

  // Should only generate for kEndHourId
  ASSERT_TRUE(observation);
  ASSERT_EQ(observation->index_histogram().index_histograms_size(), kNumEventCodes);

  for (const auto& value : observation->index_histogram().index_histograms()) {
    // Within kEndHourId each event code received exactly one value, so each histogram must hold
    // exactly one bucket. Check the repeated-field sizes first so that a malformed observation
    // fails an assertion instead of indexing past the end of a field.
    ASSERT_GE(value.event_codes_size(), 1);
    ASSERT_EQ(value.bucket_indices_size(), 1);
    ASSERT_EQ(value.bucket_counts_size(), 1);
    // Event code i was logged with value i, which is expected to land in bucket index i.
    ASSERT_EQ(value.bucket_indices(0), value.event_codes(0));
    ASSERT_EQ(value.bucket_counts(0), 1u);
  }
  // No hourly data should remain in the aggregate after the observation is generated.
  ASSERT_EQ(aggregate.hourly().by_hour_id_size(), 0);
}

TEST_F(IntegerHistogramAggregationProcedureTest, GenerateObservationWorksIntegerHistogram) {
  uint32_t metric_id = kIntegerHistogramMetricMetricId;
  auto procedure =
      GetProcedureFor(metric_id, kIntegerHistogramMetricFleetwideHistogramsReportIndex);

  ReportAggregate aggregate;
  const uint32_t kNumEventCodes = 10;
  ASSERT_GE(GetMetricDef(metric_id).event_code_buffer_max(), kNumEventCodes);

  // Log the same histogram for every event code in every other hour: 1, 3, 5, 7, 9 and 11. The
  // counter is unsigned so it compares cleanly against kEndHourId.
  const uint32_t kEndHourId = 11;
  const std::map<uint32_t, uint64_t> kLoggedHistogram = {{1, 10}, {2, 100}, {3, 50}};
  for (uint32_t hour_id = 1; hour_id <= kEndHourId; hour_id += 2) {
    LogIntegerHistogramEvents(hour_id, kNumEventCodes, kLoggedHistogram, procedure.get(),
                              &aggregate);
  }

  auto observation_or =
      GenerateObservation(util::TimeInfo::FromHourId(kEndHourId), procedure.get(), &aggregate);
  ASSERT_TRUE(observation_or.ok());

  auto observation = observation_or.ConsumeValueOrDie();

  // Should only generate for kEndHourId
  ASSERT_TRUE(observation);
  ASSERT_EQ(observation->index_histogram().index_histograms_size(), kNumEventCodes);

  for (const auto& value : observation->index_histogram().index_histograms()) {
    // Each event code logged kLoggedHistogram exactly once in kEndHourId. Require every logged
    // bucket to be present so that an empty or truncated histogram cannot pass vacuously.
    ASSERT_EQ(value.bucket_indices_size(), static_cast<int>(kLoggedHistogram.size()));
    ASSERT_EQ(value.bucket_counts_size(), value.bucket_indices_size());
    for (int i = 0; i < value.bucket_indices_size(); i++) {
      // Look up with find() so an unexpected index fails an assertion with a clear message
      // instead of throwing std::out_of_range from map::at().
      auto expected = kLoggedHistogram.find(value.bucket_indices(i));
      ASSERT_TRUE(expected != kLoggedHistogram.end())
          << "unexpected bucket index " << value.bucket_indices(i);
      ASSERT_EQ(expected->second, value.bucket_counts(i));
    }
  }
  // No hourly data should remain in the aggregate after the observation is generated.
  ASSERT_EQ(aggregate.hourly().by_hour_id_size(), 0);
}

}  // namespace cobalt::local_aggregation
