// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "src/local_aggregation_1_1/observation_generator.h"

#include <memory>

#include <gtest/gtest.h>

#include "absl/strings/escaping.h"
#include "src/algorithms/random/test_secure_random.h"
#include "src/lib/util/clock.h"
#include "src/lib/util/datetime_util.h"
#include "src/lib/util/testing/test_with_files.h"
#include "src/local_aggregation_1_1/aggregation_procedures/aggregation_procedure.h"
#include "src/local_aggregation_1_1/local_aggregate_storage/immediate_local_aggregate_storage.h"
#include "src/local_aggregation_1_1/local_aggregate_storage/local_aggregate_storage.h"
#include "src/local_aggregation_1_1/local_aggregation.pb.h"
#include "src/local_aggregation_1_1/testing/test_registry.cb.h"
#include "src/logger/observation_writer.h"
#include "src/logger/privacy_encoder.h"
#include "src/logger/project_context_factory.h"
#include "src/observation_store/observation_store.h"
#include "src/observation_store/observation_store_internal.pb.h"
#include "src/pb/metadata_builder.h"
#include "src/pb/observation_batch.pb.h"
#include "src/system_data/client_secret.h"
#include "src/system_data/fake_system_data.h"

namespace cobalt::local_aggregation {

// Short aliases for types used throughout these tests.
using MetricAggregateRef = LocalAggregateStorage::MetricAggregateRef;
using TimeInfo = util::TimeInfo;

// Number of bytes every generated Observation is expected to carry in its
// |random_id| field.
constexpr size_t kRandomIdSize{8};

namespace {

std::unique_ptr<CobaltRegistry> GetRegistry() {
  std::string bytes;
  if (!absl::Base64Unescape(kCobaltRegistryBase64, &bytes)) {
    LOG(ERROR) << "Unable to decode Base64 String";
    return nullptr;
  }

  auto registry = std::make_unique<CobaltRegistry>();
  if (!registry->ParseFromString(bytes)) {
    LOG(ERROR) << "Unable to parse registry from bytes";
    return nullptr;
  }

  return registry;
}

// Test double for logger::PrivacyEncoder that never applies a real privacy
// mechanism.
//
// - When |return_private_observations| is false, the observation passed to
//   MaybeMakePrivateObservations() is returned unchanged (as a one-element vector).
// - When it is true, the input observation is discarded and
//   |num_private_observations| default-constructed (empty) Observations are
//   returned instead. This lets tests check how the caller handles an encoder
//   that turns one observation into several.
class FakePrivacyEncoder : public logger::PrivacyEncoder {
 public:
  explicit FakePrivacyEncoder(bool return_private_observations, int num_private_observations = 0)
      : PrivacyEncoder(std::make_unique<TestSecureRandomNumberGenerator>(0),
                       std::make_unique<RandomNumberGenerator>(0)),
        return_private_observations_(return_private_observations),
        num_private_observations_(num_private_observations) {}

  lib::statusor::StatusOr<std::vector<std::unique_ptr<Observation>>> MaybeMakePrivateObservations(
      std::unique_ptr<Observation> observation, const MetricDefinition& /*metric_def*/,
      const ReportDefinition& /*report_def*/) override {
    std::vector<std::unique_ptr<Observation>> observations;

    if (!return_private_observations_) {
      // Pass-through mode: forward the caller's observation untouched.
      observations.push_back(std::move(observation));
      return observations;
    }

    // Private mode: ignore the input and emit |num_private_observations_| empty
    // observations. A non-positive count yields an empty vector. The new
    // observations are created inline rather than in a local named
    // |observation|, which previously shadowed the parameter.
    for (int i = 0; i < num_private_observations_; ++i) {
      observations.push_back(std::make_unique<Observation>());
    }

    return observations;
  }

 private:
  bool return_private_observations_;
  int num_private_observations_;
};

}  // namespace

// Test fixture that wires an ObservationGenerator to a LocalAggregateStorage
// (Immediate strategy, rooted in the per-test folder) built from the embedded
// test registry.
//
// Typical test flow:
//   1. Populate aggregates through GetMetricAggregate() and Save() them.
//   2. Call ConstructObservationGenerator() with an ObservationWriter and a
//      FakePrivacyEncoder.
//   3. Drive generation with GenerateObservationsOnce().
class ObservationGeneratorTest : public util::testing::TestWithFiles {
 public:
  void SetUp() override {
    MakeTestFolder();
    metadata_builder_ =
        std::make_unique<MetadataBuilder>(&system_data_, system_data_cache_path(), fs());

    // Everything below depends on the registry, so stop here (the test body is
    // skipped) if it cannot be decoded or parsed.
    std::unique_ptr<CobaltRegistry> registry = GetRegistry();
    ASSERT_TRUE(registry) << "Failed to load the test CobaltRegistry";
    project_context_factory_ = std::make_unique<logger::ProjectContextFactory>(std::move(registry));

    aggregate_storage_ =
        LocalAggregateStorage::New(LocalAggregateStorage::StorageStrategy::Immediate, test_folder(),
                                   fs(), project_context_factory_.get());
  }

  // Returns the stored aggregate for |metric_id| in the test customer/project.
  // A failed lookup is recorded as a (non-fatal) test failure before
  // ConsumeValueOrDie() is reached with the error status.
  MetricAggregateRef GetMetricAggregate(uint32_t metric_id) {
    lib::statusor::StatusOr<MetricAggregateRef> metric_aggregate_or =
        aggregate_storage_->GetMetricAggregate(kCustomerId, kProjectId, metric_id);
    EXPECT_TRUE(metric_aggregate_or.ok());

    return metric_aggregate_or.ConsumeValueOrDie();
  }

  // Creates the generator under test. |observation_writer| is stored as a raw
  // pointer by the generator, so it must outlive every later call to
  // GenerateObservationsOnce().
  void ConstructObservationGenerator(const logger::ObservationWriter* observation_writer,
                                     std::unique_ptr<FakePrivacyEncoder> privacy_encoder) {
    observation_generator_ = std::make_unique<ObservationGenerator>(
        aggregate_storage_.get(), project_context_factory_.get(), metadata_builder_.get(),
        observation_writer, std::move(privacy_encoder));
  }

  // The generator only exists once ConstructObservationGenerator() has run. A
  // test that fails (fatally) or returns before that point must not crash here.
  void TearDown() override {
    if (observation_generator_ != nullptr) {
      observation_generator_->ShutDown();
    }
  }

  // Runs one generation pass for the given UTC and local times.
  util::Status GenerateObservationsOnce(util::TimeInfo utc, util::TimeInfo local) {
    return observation_generator_->GenerateObservationsOnce(utc, local);
  }

 private:
  system_data::FakeSystemData system_data_;

  std::unique_ptr<MetadataBuilder> metadata_builder_;
  std::unique_ptr<logger::ProjectContextFactory> project_context_factory_;
  std::unique_ptr<LocalAggregateStorage> aggregate_storage_;
  std::unique_ptr<ObservationGenerator> observation_generator_;
};

class TestObservationStoreWriter : public observation_store::ObservationStoreWriterInterface {
 public:
  explicit TestObservationStoreWriter(
      std::function<void(std::unique_ptr<observation_store::StoredObservation>,
                         std::unique_ptr<ObservationMetadata>)>
          watcher)
      : watcher_(std::move(watcher)) {}

  StoreStatus StoreObservation(std::unique_ptr<observation_store::StoredObservation> observation,
                               std::unique_ptr<ObservationMetadata> metadata) override {
    watcher_(std::move(observation), std::move(metadata));
    return StoreStatus::kOk;
  }

 private:
  std::function<void(std::unique_ptr<observation_store::StoredObservation>,
                     std::unique_ptr<ObservationMetadata>)>
      watcher_;
};

TEST_F(ObservationGeneratorTest, GeneratesHourlyObservationsAsExpected) {
  const uint32_t kMaxHourId = 101;

  // Seed the FleetwideOccurrenceCounts report with a count of i * 100 for every
  // odd hour id i in [1, kMaxHourId].
  {
    MetricAggregateRef aggregate = GetMetricAggregate(kOccurrenceMetricMetricId);
    ReportAggregate* report =
        &(*aggregate.aggregate()
               ->mutable_by_report_id())[kOccurrenceMetricFleetwideOccurrenceCountsReportReportId];
    for (uint32_t i = 1; i <= kMaxHourId; i += 2) {
      (*report->mutable_hourly()->mutable_by_hour_id())[i]
          .add_by_event_code()
          ->mutable_data()
          ->set_count(i * 100);
    }
    ASSERT_TRUE(aggregate.Save().ok());
  }

  // Keep only the most recent observation (and its metadata) written for the
  // FleetwideOccurrenceCounts report; writes for other reports are ignored.
  std::unique_ptr<ObservationMetadata> last_metadata;
  std::unique_ptr<Observation> last_observation;
  TestObservationStoreWriter test_writer(
      [&last_metadata, &last_observation](
          std::unique_ptr<observation_store::StoredObservation> observation,
          std::unique_ptr<ObservationMetadata> metadata) {
        if (metadata->report_id() == kOccurrenceMetricFleetwideOccurrenceCountsReportReportId) {
          last_metadata = std::move(metadata);
          if (observation->has_unencrypted()) {
            last_observation = std::unique_ptr<Observation>(observation->release_unencrypted());
          }
        }
      });

  logger::ObservationWriter observation_writer(&test_writer, nullptr);
  ConstructObservationGenerator(&observation_writer, std::make_unique<FakePrivacyEncoder>(false));

  // Every hour id visited here (1, 5, 9, ...) is odd, so each one was seeded above.
  for (uint32_t i = 1; i <= kMaxHourId; i += 4) {
    // Drop results from the previous iteration so that a missing observation
    // for hour |i| fails the test instead of re-checking stale data.
    last_metadata.reset();
    last_observation.reset();

    const TimeInfo hour = TimeInfo::FromHourId(i);
    EXPECT_TRUE(GenerateObservationsOnce(hour, hour).ok()) << "Error for i: " << i;

    // ASSERT rather than EXPECT: the checks below dereference last_metadata.
    ASSERT_TRUE(last_metadata) << "Error for i: " << i;
    EXPECT_EQ(last_metadata->customer_id(), kCustomerId);
    EXPECT_EQ(last_metadata->day_index(), util::HourIdToDayIndex(i)) << "Error for i: " << i;
    ASSERT_TRUE(last_observation) << "Error for i: " << i;
    EXPECT_EQ(last_observation->random_id().size(), kRandomIdSize);
    ASSERT_TRUE(last_observation->has_integer());
    ASSERT_GT(last_observation->integer().values_size(), 0);
    EXPECT_EQ(last_observation->integer().values(0).value(), i * 100);
  }
}

TEST_F(ObservationGeneratorTest, GeneratesDailyObservationsAsExpected) {
  const uint32_t kMaxDayIndex = 5;

  // Mark the 1-day UniqueDeviceCounts report as "occurred at least once" on
  // every day index in [1, kMaxDayIndex].
  {
    MetricAggregateRef aggregate = GetMetricAggregate(kOccurrenceMetricMetricId);
    ReportAggregate* report =
        &(*aggregate.aggregate()
               ->mutable_by_report_id())[kOccurrenceMetricUniqueDeviceCountsReport1DayReportId];
    for (uint32_t i = 1; i <= kMaxDayIndex; ++i) {
      (*report->mutable_daily()->mutable_by_day_index())[i]
          .add_by_event_code()
          ->mutable_data()
          ->set_at_least_once(true);
    }
    ASSERT_TRUE(aggregate.Save().ok());
  }

  // Keep only the most recent observation (and its metadata) written for the
  // 1-day UniqueDeviceCounts report; writes for other reports are ignored.
  std::unique_ptr<ObservationMetadata> last_metadata;
  std::unique_ptr<Observation> last_observation;
  TestObservationStoreWriter test_writer(
      [&last_metadata, &last_observation](
          std::unique_ptr<observation_store::StoredObservation> observation,
          std::unique_ptr<ObservationMetadata> metadata) {
        if (metadata->report_id() == kOccurrenceMetricUniqueDeviceCountsReport1DayReportId) {
          last_metadata = std::move(metadata);
          if (observation->has_unencrypted()) {
            last_observation = std::unique_ptr<Observation>(observation->release_unencrypted());
          }
        }
      });

  logger::ObservationWriter observation_writer(&test_writer, nullptr);
  ConstructObservationGenerator(&observation_writer, std::make_unique<FakePrivacyEncoder>(false));

  for (uint32_t i = 1; i <= kMaxDayIndex; ++i) {
    // Drop results from the previous iteration so that a missing observation
    // for day |i| fails the test instead of re-checking stale data.
    last_metadata.reset();
    last_observation.reset();

    const TimeInfo day = TimeInfo::FromDayIndex(i);
    EXPECT_TRUE(GenerateObservationsOnce(day, day).ok()) << "Error for i: " << i;

    // ASSERT rather than EXPECT: the checks below dereference last_metadata,
    // so a missing value must stop this test instead of crashing it.
    ASSERT_TRUE(last_metadata) << "Error for i: " << i;
    EXPECT_EQ(last_metadata->customer_id(), kCustomerId);
    EXPECT_EQ(last_metadata->day_index(), i);
    ASSERT_TRUE(last_observation) << "Error for i: " << i;
    EXPECT_EQ(last_observation->random_id().size(), kRandomIdSize);
    ASSERT_TRUE(last_observation->has_integer());
    ASSERT_EQ(last_observation->integer().values_size(), 1);
    EXPECT_EQ(last_observation->integer().values(0).value(), 1);
  }
}

TEST_F(ObservationGeneratorTest, GeneratesPrivateObservations) {
  const uint32_t kMaxHourId = 101;
  const int kNumPrivateObs = 2;

  // Seed the FleetwideOccurrenceCounts report with a count of i * 100 for every
  // odd hour id i in [1, kMaxHourId].
  {
    MetricAggregateRef aggregate = GetMetricAggregate(kOccurrenceMetricMetricId);
    auto* report =
        &(*aggregate.aggregate()
               ->mutable_by_report_id())[kOccurrenceMetricFleetwideOccurrenceCountsReportReportId];
    for (uint32_t i = 1; i <= kMaxHourId; i += 2) {
      (*report->mutable_hourly()->mutable_by_hour_id())[i]
          .add_by_event_code()
          ->mutable_data()
          ->set_count(i * 100);
    }
    ASSERT_TRUE(aggregate.Save().ok());
  }

  // Collect every unencrypted observation written for the
  // FleetwideOccurrenceCounts report.
  std::vector<Observation> observations;
  observations.reserve(static_cast<size_t>(kNumPrivateObs));
  TestObservationStoreWriter test_writer(
      [&observations](std::unique_ptr<observation_store::StoredObservation> observation,
                      std::unique_ptr<ObservationMetadata> metadata) {
        if (metadata->report_id() == kOccurrenceMetricFleetwideOccurrenceCountsReportReportId) {
          if (observation->has_unencrypted()) {
            observations.push_back(observation->unencrypted());
          }
        }
      });

  // The fake encoder replaces each observation with kNumPrivateObs empty ones.
  logger::ObservationWriter observation_writer(&test_writer, nullptr);
  ConstructObservationGenerator(&observation_writer,
                                std::make_unique<FakePrivacyEncoder>(true, kNumPrivateObs));

  const TimeInfo hour = TimeInfo::FromHourId(1);
  EXPECT_TRUE(GenerateObservationsOnce(hour, hour).ok());

  // Compare as size_t so EXPECT_EQ does not perform a signed/unsigned comparison.
  EXPECT_EQ(observations.size(), static_cast<size_t>(kNumPrivateObs));
  for (const Observation& obs : observations) {
    EXPECT_EQ(obs.random_id().size(), kRandomIdSize);
  }
}

}  // namespace cobalt::local_aggregation
