// Copyright 2018 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "src/logger/encoder.h"

#include <memory>
#include <string>

#include "src/algorithms/rappor/rappor_config_helper.h"
#include "src/algorithms/rappor/rappor_encoder.h"
#include "src/lib/client/cpp/buckets_config.h"
#include "src/lib/crypto_util/hash.h"
#include "src/lib/util/datetime_util.h"
#include "src/logger/project_context.h"
#include "src/logging.h"
#include "src/pb/metadata_builder.h"
#include "src/pb/observation.pb.h"
#include "src/registry/packed_event_codes.h"
#include "src/tracing.h"

namespace cobalt::logger {

using ::cobalt::config::IntegerBucketConfig;
using ::cobalt::crypto::byte;
using ::cobalt::crypto::hash::DIGEST_SIZE;
using ::cobalt::rappor::BasicRapporEncoder;
using ::cobalt::rappor::RapporConfigHelper;
using ::cobalt::system_data::ClientSecret;
using ::google::protobuf::RepeatedField;

namespace {
// Stores the SHA256 digest of |component| in |*hash_out|.
//
// An empty |component| signals that the component_name feature is not being
// used. That is expected to be the common case, and spending 32 bytes on the
// digest of an empty string would be pointless, so in that case |*hash_out|
// is made empty as well.
//
// |hash_out| must not be null. Returns true on success and false if hashing
// fails (unexpected).
bool HashComponentNameIfNotEmpty(const std::string& component, std::string* hash_out) {
  CHECK(hash_out);
  if (!component.empty()) {
    // Make room for the digest, then hash straight into the string's buffer.
    hash_out->resize(DIGEST_SIZE);
    const auto* input_bytes = reinterpret_cast<const byte*>(component.data());
    auto* digest_bytes = reinterpret_cast<byte*>(&hash_out->front());
    return cobalt::crypto::hash::Hash(input_bytes, component.size(), digest_bytes);
  }
  hash_out->clear();
  return true;
}

// Translates a rappor::Status |status| into a logger::Status and logs an
// error naming |report| and |metric| if |status| is not kOK.
//
// |report| must not be null; it is dereferenced when building the error
// message.
//
// Returns:
//   kOK               for rappor::kOK
//   kInvalidConfig    for rappor::kInvalidConfig
//   kInvalidArguments for rappor::kInvalidInput
//   kOther            for any value that is not one of the enumerators above
Status TranslateBasicRapporEncoderStatus(MetricRef metric, const ReportDefinition* report,
                                         const rappor::Status& status) {
  // Deliberately no `default:` label so that the compiler can still warn when
  // a new rappor::Status enumerator is added without being handled here.
  switch (status) {
    case rappor::kOK:
      return kOK;
    case rappor::kInvalidConfig:
      LOG(ERROR) << "BasicRapporEncoder returned kInvalidConfig for: Report "
                 << report->report_name() << " for metric " << metric.metric_name()
                 << " in project " << metric.ProjectDebugString() << ".";
      return kInvalidConfig;
    case rappor::kInvalidInput:
      LOG(ERROR) << "BasicRapporEncoder returned kInvalidInput for: Report "
                 << report->report_name() << " for metric " << metric.metric_name()
                 << " in project " << metric.ProjectDebugString() << ".";
      return kInvalidArguments;
  }

  // Reached only if |status| holds a value outside the enumerators handled
  // above. Without this return, control would flow off the end of a non-void
  // function, which is undefined behavior.
  LOG(ERROR) << "BasicRapporEncoder returned an unrecognized status ("
             << static_cast<int>(status) << ") for: Report " << report->report_name()
             << " for metric " << metric.metric_name() << " in project "
             << metric.ProjectDebugString() << ".";
  return kOther;
}

}  // namespace

// Constructs an Encoder.
//
// |client_secret| is taken by value and moved into the encoder; it is later
// handed to the BasicRapporEncoder instances created by the Encode*() methods.
//
// |metadata_builder| is stored as a raw pointer and is dereferenced without a
// null check whenever observation metadata is built, so it must be non-null.
// NOTE(review): presumably it is not owned by the Encoder and must outlive
// it -- confirm against the header.
Encoder::Encoder(ClientSecret client_secret, MetadataBuilder* metadata_builder)
    : client_secret_(std::move(client_secret)), metadata_builder_(metadata_builder) {}

// Produces a Basic RAPPOR observation that encodes the category at
// |value_index| out of |num_categories| indexed categories.
//
// The bit-flip probability is obtained from RapporConfigHelper for |report|.
// The returned Result carries freshly built metadata, the encoded observation
// and the translated status of the RAPPOR encoder.
Encoder::Result Encoder::EncodeBasicRapporObservation(MetricRef metric,
                                                      const ReportDefinition* report,
                                                      uint32_t day_index, uint32_t value_index,
                                                      uint32_t num_categories) const {
  TRACE_DURATION("cobalt_core", "Encoder::EncodeBasicRapporObservation");

  Encoder::Result encoded = NewObservationWithMetadata(metric, report, day_index);
  auto* rappor_out = encoded.observation->mutable_basic_rappor();

  // Assemble the encoder configuration for this report.
  rappor::BasicRapporConfig config;
  config.prob_rr = RapporConfigHelper::kProbRR;
  config.categories.set_indexed(num_categories);
  const float flip_probability =
      RapporConfigHelper::ProbBitFlip(*report, metric.FullyQualifiedName());
  config.prob_0_becomes_1 = flip_probability;
  config.prob_1_stays_1 = 1.0f - flip_probability;

  // TODO(rudominer) Stop copying the client_secret_ on each Encode*()
  // operation.
  BasicRapporEncoder encoder(config, client_secret_);

  ValuePart category;
  category.set_index_value(value_index);
  const auto encode_status = encoder.Encode(category, rappor_out);

  encoded.status = TranslateBasicRapporEncoderStatus(metric, report, encode_status);
  return encoded;
}

// Produces an observation whose numeric_event field holds the packed
// |event_codes|, the hash of |component| (empty when |component| is empty) and
// |value|.
//
// If hashing the component name fails, an error is logged and the status of
// the returned Result is set to kOther; the remaining fields are still filled.
Encoder::Result Encoder::EncodeIntegerEventObservation(
    MetricRef metric, const ReportDefinition* report, uint32_t day_index,
    const RepeatedField<uint32_t>& event_codes, const std::string& component, int64_t value) const {
  Encoder::Result encoded = NewObservationWithMetadata(metric, report, day_index);
  auto* numeric_event = encoded.observation->mutable_numeric_event();
  numeric_event->set_event_code(config::PackEventCodes(event_codes));

  const bool hashed_ok =
      HashComponentNameIfNotEmpty(component, numeric_event->mutable_component_name_hash());
  if (!hashed_ok) {
    LOG(ERROR) << "Hashing the component name failed for: Report " << report->report_name()
               << " for metric " << metric.metric_name() << " in project "
               << metric.ProjectDebugString() << ".";
    encoded.status = kOther;
  }

  numeric_event->set_value(value);
  return encoded;
}

// Produces an observation whose histogram field holds the packed
// |event_codes|, the hash of |component| (empty when |component| is empty) and
// the buckets from |histogram|.
//
// The buckets are swapped out of |*histogram| rather than copied. If hashing
// the component name fails, an error is logged and the status of the returned
// Result is set to kOther; the buckets are still transferred.
Encoder::Result Encoder::EncodeHistogramObservation(MetricRef metric,
                                                    const ReportDefinition* report,
                                                    uint32_t day_index,
                                                    const RepeatedField<uint32_t>& event_codes,
                                                    const std::string& component,
                                                    HistogramPtr histogram) const {
  Encoder::Result encoded = NewObservationWithMetadata(metric, report, day_index);
  auto* histogram_out = encoded.observation->mutable_histogram();
  histogram_out->set_event_code(config::PackEventCodes(event_codes));

  const bool hashed_ok =
      HashComponentNameIfNotEmpty(component, histogram_out->mutable_component_name_hash());
  if (!hashed_ok) {
    LOG(ERROR) << "Hashing the component name failed for: Report " << report->report_name()
               << " for metric " << metric.metric_name() << " in project "
               << metric.ProjectDebugString() << ".";
    encoded.status = kOther;
  }

  auto* buckets_out = histogram_out->mutable_buckets();
  buckets_out->Swap(histogram.get());
  return encoded;
}

// Produces an observation whose custom field holds the entries of
// |event_values|. The entries are swapped out of |*event_values| rather than
// copied.
Encoder::Result Encoder::EncodeCustomObservation(MetricRef metric, const ReportDefinition* report,
                                                 uint32_t day_index,
                                                 EventValuesPtr event_values) const {
  Encoder::Result encoded = NewObservationWithMetadata(metric, report, day_index);
  auto* values_out = encoded.observation->mutable_custom()->mutable_values();
  values_out->swap(*event_values);
  return encoded;
}

// Produces an observation whose custom field holds the bytes of
// |serialized_proto|. The bytes are swapped out of |*serialized_proto| rather
// than copied.
Encoder::Result Encoder::EncodeSerializedCustomObservation(
    MetricRef metric, const ReportDefinition* report, uint32_t day_index,
    std::unique_ptr<std::string> serialized_proto) const {
  Encoder::Result encoded = NewObservationWithMetadata(metric, report, day_index);
  auto* serialized_out = encoded.observation->mutable_custom()->mutable_serialized_proto();
  serialized_out->swap(*serialized_proto);
  return encoded;
}

// Produces a unique-actives observation for |event_code| over
// |aggregation_window|.
//
// Activity is represented by a single-category Basic RAPPOR observation: the
// category bit is encoded as set when |was_active| is true, and a null
// observation is encoded otherwise. If that inner encoding fails, its status
// is returned and the unique_actives field is left unpopulated.
Encoder::Result Encoder::EncodeUniqueActivesObservation(
    MetricRef metric, const ReportDefinition* report, uint32_t day_index, uint32_t event_code,
    bool was_active, const OnDeviceAggregationWindow& aggregation_window) const {
  Encoder::Result encoded = NewObservationWithMetadata(metric, report, day_index);

  // A single 1 bit when active, a single 0 bit otherwise.
  Encoder::Result activity_bit =
      was_active ? EncodeBasicRapporObservation(metric, report, day_index, 0u, 1u)
                 : EncodeNullBasicRapporObservation(metric, report, day_index, 1u);
  if (activity_bit.status != kOK) {
    encoded.status = activity_bit.status;
    return encoded;
  }

  auto* unique_actives = encoded.observation->mutable_unique_actives();
  unique_actives->mutable_aggregation_window()->CopyFrom(aggregation_window);
  unique_actives->set_event_code(event_code);

  // Move the encoded bits from the temporary observation into the result.
  auto* encoded_bits = activity_bit.observation->mutable_basic_rappor()->mutable_data();
  unique_actives->mutable_basic_rappor_obs()->mutable_data()->swap(*encoded_bits);

  return encoded;
}

// Produces a per-device numeric observation for |aggregation_window|.
//
// The integer-event payload (packed |event_codes|, component hash and |value|)
// is built by EncodeIntegerEventObservation() and then re-parented under the
// per_device_numeric field of the same observation. The status set by
// EncodeIntegerEventObservation() is returned unchanged, including on failure.
Encoder::Result Encoder::EncodePerDeviceNumericObservation(
    MetricRef metric, const ReportDefinition* report, uint32_t day_index,
    const std::string& component, const RepeatedField<uint32_t>& event_codes, int64_t value,
    const OnDeviceAggregationWindow& aggregation_window) const {
  auto result =
      EncodeIntegerEventObservation(metric, report, day_index, event_codes, component, value);
  // Detach the numeric_event sub-message BEFORE touching per_device_numeric.
  // NOTE(review): presumably both are members of the same oneof, in which case
  // calling mutable_per_device_numeric() first would discard the numeric
  // event -- confirm against observation.proto. Keep this statement order.
  auto* integer_event_observation = result.observation->release_numeric_event();
  auto* per_device_observation = result.observation->mutable_per_device_numeric();
  // Hand the detached message over to the per-device observation.
  per_device_observation->set_allocated_integer_event_obs(integer_event_observation);
  per_device_observation->mutable_aggregation_window()->CopyFrom(aggregation_window);
  return result;
}

// Produces a per-device histogram observation for |aggregation_window|.
//
// The histogram holds the packed |event_codes|, the hash of |component|
// (empty when |component| is empty) and a single bucket with count 1 whose
// index is the bucket that |value| falls into according to the report's
// int_buckets configuration.
//
// The status of the returned Result is set to kOther, and an error is logged,
// if hashing the component name fails or if the bucket configuration is
// invalid. In the latter case no bucket is added.
Encoder::Result Encoder::EncodePerDeviceHistogramObservation(
    MetricRef metric, const ReportDefinition* report, uint32_t day_index,
    const std::string& component, const RepeatedField<uint32_t>& event_codes, int64_t value,
    const OnDeviceAggregationWindow& aggregation_window) const {
  Encoder::Result encoded = NewObservationWithMetadata(metric, report, day_index);

  auto* per_device = encoded.observation->mutable_per_device_histogram();
  per_device->mutable_aggregation_window()->CopyFrom(aggregation_window);

  auto* histogram_out = per_device->mutable_histogram();
  histogram_out->set_event_code(config::PackEventCodes(event_codes));

  const bool hashed_ok =
      HashComponentNameIfNotEmpty(component, histogram_out->mutable_component_name_hash());
  if (!hashed_ok) {
    LOG(ERROR) << "Hashing the component name failed for: Report " << report->report_name()
               << " for metric " << metric.metric_name() << " in project "
               << metric.ProjectDebugString() << ".";
    encoded.status = kOther;
  }

  auto bucket_config = IntegerBucketConfig::CreateFromProto(report->int_buckets());
  if (bucket_config != nullptr) {
    // Record |value| as one occurrence in the bucket it belongs to.
    auto* single_bucket = histogram_out->add_buckets();
    single_bucket->set_index(bucket_config->BucketIndex(value));
    single_bucket->set_count(1);
    return encoded;
  }

  LOG(ERROR) << "Invalid IntBucketConfig for: Report " << report->report_name() << " for metric "
             << metric.metric_name() << " in project " << metric.ProjectDebugString() << ".";
  encoded.status = kOther;
  return encoded;
}

// Produces an observation whose report_participation field is present but
// otherwise left at its defaults.
Encoder::Result Encoder::EncodeReportParticipationObservation(MetricRef metric,
                                                              const ReportDefinition* report,
                                                              uint32_t day_index) const {
  Encoder::Result encoded = NewObservationWithMetadata(metric, report, day_index);
  // Calling the mutable accessor is enough to mark the field as set.
  encoded.observation->mutable_report_participation();
  return encoded;
}

// Builds an Observation whose integer field contains one value entry per
// element of |data|. Each element is an (event codes, value) tuple; the event
// codes are appended in order and the value is stored alongside them.
lib::statusor::StatusOr<std::unique_ptr<Observation>> Encoder::EncodeIntegerObservation(
    const std::vector<std::tuple<std::vector<uint32_t>, int64_t>>& data) {
  auto observation = std::make_unique<Observation>();
  auto* integer_out = observation->mutable_integer();

  for (const auto& [event_codes, value] : data) {
    auto* entry = integer_out->add_values();
    for (const uint32_t code : event_codes) {
      entry->add_event_codes(code);
    }
    entry->set_value(value);
  }

  return observation;
}

// Builds an Observation whose sum_and_count field contains one entry per
// element of |data|. Each element is an (event codes, sum, count) tuple; the
// event codes are appended in order and the sum and count are stored alongside
// them.
lib::statusor::StatusOr<std::unique_ptr<Observation>> Encoder::EncodeSumAndCountObservation(
    const std::vector<std::tuple<std::vector<uint32_t>, int64_t, uint32_t>>& data) {
  auto observation = std::make_unique<Observation>();

  SumAndCountObservation* sum_and_count_observation = observation->mutable_sum_and_count();

  // Bind by const reference: binding by value would deep-copy every tuple,
  // including its event-code vector, on each iteration.
  for (const auto& [event_codes, sum, count] : data) {
    SumAndCountObservation_SumAndCount* sum_and_count =
        sum_and_count_observation->add_sums_and_counts();
    for (const uint32_t code : event_codes) {
      sum_and_count->add_event_codes(code);
    }
    sum_and_count->set_sum(sum);
    sum_and_count->set_count(count);
  }

  return observation;
}

// Builds an Observation whose index_histogram field contains one histogram per
// element of |data|. Each element is an (event codes, histogram) tuple, where
// the histogram is a list of (bucket index, bucket count) pairs. Bucket indices
// and counts are appended to parallel repeated fields in the order given.
lib::statusor::StatusOr<std::unique_ptr<Observation>> Encoder::EncodeIndexHistogramObservation(
    const std::vector<
        std::tuple<std::vector<uint32_t>, std::vector<std::tuple<uint32_t, int64_t>>>>& data) {
  auto observation = std::make_unique<Observation>();

  IndexHistogramObservation* index_histogram_observation = observation->mutable_index_histogram();

  // Bind by const reference: binding by value would deep-copy both vectors of
  // every tuple on each iteration.
  for (const auto& [event_codes, histogram] : data) {
    IndexHistogram* index_histogram = index_histogram_observation->add_index_histograms();
    for (const uint32_t code : event_codes) {
      index_histogram->add_event_codes(code);
    }
    for (const auto& [index, count] : histogram) {
      index_histogram->add_bucket_indices(index);
      index_histogram->add_bucket_counts(count);
    }
  }

  return observation;
}

// Builds an Observation whose string_histogram field contains every entry of
// |hashes|, in order, followed by one histogram per element of |data|. Each
// element of |data| is an (event codes, histogram) tuple, where the histogram
// is a list of (bucket index, bucket count) pairs. Bucket indices and counts
// are appended to parallel repeated fields in the order given.
lib::statusor::StatusOr<std::unique_ptr<Observation>> Encoder::EncodeStringHistogramObservation(
    const std::vector<std::string>& hashes,
    const std::vector<
        std::tuple<std::vector<uint32_t>, std::vector<std::tuple<uint32_t, int64_t>>>>& data) {
  auto observation = std::make_unique<Observation>();

  StringHistogramObservation* string_histogram_observation =
      observation->mutable_string_histogram();

  for (const auto& hash : hashes) {
    string_histogram_observation->add_string_hashes(hash);
  }

  // Bind by const reference: binding by value would deep-copy both vectors of
  // every tuple on each iteration.
  for (const auto& [event_codes, histogram] : data) {
    IndexHistogram* index_histogram = string_histogram_observation->add_string_histograms();
    for (const uint32_t code : event_codes) {
      index_histogram->add_event_codes(code);
    }
    for (const auto& [index, count] : histogram) {
      index_histogram->add_bucket_indices(index);
      index_histogram->add_bucket_counts(count);
    }
  }

  return observation;
}

// Produces a Basic RAPPOR "null" observation over |num_categories| indexed
// categories, i.e. the encoder's EncodeNullObservation() output rather than
// the encoding of a particular category.
//
// The bit-flip probability is obtained from RapporConfigHelper for |report|.
// The returned Result carries freshly built metadata, the encoded observation
// and the translated status of the RAPPOR encoder.
Encoder::Result Encoder::EncodeNullBasicRapporObservation(MetricRef metric,
                                                          const ReportDefinition* report,
                                                          uint32_t day_index,
                                                          uint32_t num_categories) const {
  Encoder::Result encoded = NewObservationWithMetadata(metric, report, day_index);
  auto* rappor_out = encoded.observation->mutable_basic_rappor();

  // Assemble the encoder configuration for this report.
  rappor::BasicRapporConfig config;
  config.prob_rr = RapporConfigHelper::kProbRR;
  config.categories.set_indexed(num_categories);
  const float flip_probability =
      RapporConfigHelper::ProbBitFlip(*report, metric.FullyQualifiedName());
  config.prob_0_becomes_1 = flip_probability;
  config.prob_1_stays_1 = 1.0f - flip_probability;

  // TODO(rudominer) Stop copying the client_secret_ on each Encode*()
  // operation.
  BasicRapporEncoder encoder(config, client_secret_);
  const auto encode_status = encoder.EncodeNullObservation(rappor_out);

  encoded.status = TranslateBasicRapporEncoderStatus(metric, report, encode_status);
  return encoded;
}

// Creates the starting point shared by the Encode*() methods: a Result with
// status kOK, a newly allocated empty Observation, and metadata built by the
// metadata builder for |metric|, |report|, |day_index| and the hour ID derived
// from |day_index|.
Encoder::Result Encoder::NewObservationWithMetadata(MetricRef metric,
                                                    const ReportDefinition* report,
                                                    uint32_t day_index) const {
  Result fresh;
  fresh.status = kOK;
  fresh.observation = std::make_unique<Observation>();

  const auto hour_id = util::DayIndexToHourId(day_index);
  fresh.metadata = metadata_builder_->Build(metric, *report, day_index, hour_id);

  return fresh;
}

}  // namespace cobalt::logger
