#include "src/logger/privacy_encoder.h"

#include "src/algorithms/privacy/numeric_encoding.h"
#include "src/algorithms/privacy/rappor.h"
#include "src/algorithms/random/random.h"
#include "src/lib/statusor/status_macros.h"
#include "src/logger/event_vector_index.h"
#include "src/pb/observation.pb.h"

namespace cobalt::logger {

PrivacyEncoder::PrivacyEncoder(std::unique_ptr<SecureBitGeneratorInterface<uint32_t>> secure_gen,
                               std::unique_ptr<BitGeneratorInterface<uint32_t>> gen)
    : secure_gen_(std::move(secure_gen)), gen_(std::move(gen)) {}

// Factory for a PrivacyEncoder with production generators: a SecureRandomNumberGenerator
// for the noise step and a RandomNumberGenerator for index bucketing.
std::unique_ptr<PrivacyEncoder> PrivacyEncoder::MakeSecurePrivacyEncoder() {
  auto secure_generator = std::make_unique<SecureRandomNumberGenerator>();
  auto bucketing_generator = std::make_unique<RandomNumberGenerator>();
  return std::make_unique<PrivacyEncoder>(std::move(secure_generator),
                                          std::move(bucketing_generator));
}

// Applies privacy encoding only when the report asks for it.
//
// If the report's privacy level is at most NO_ADDED_PRIVACY, `observation` is returned
// unchanged as the only element of the result vector. Otherwise the observation is replaced
// by the output of MakePrivateObservations: noisy private-index observations plus a
// participation observation. In that case the input observation is only read, then released.
lib::statusor::StatusOr<std::vector<std::unique_ptr<Observation>>>
PrivacyEncoder::MaybeMakePrivateObservations(std::unique_ptr<Observation> observation,
                                             const MetricDefinition &metric_def,
                                             const ReportDefinition &report_def) {
  if (report_def.privacy_level() <= ReportDefinition::NO_ADDED_PRIVACY) {
    // Pass-through: no privacy mechanism applies to this report.
    std::vector<std::unique_ptr<Observation>> passthrough;
    passthrough.push_back(std::move(observation));
    return passthrough;
  }

  return MakePrivateObservations(observation.get(), metric_def, report_def);
}

// Encodes `observation` into a set of private-index observations for `report_def`.
//
// Steps:
//   1. Convert the observation into indices (PrepareIndexVector).
//   2. Add RAPPOR noise to those indices (AddNoise).
//   3. Wrap each resulting index in its own Observation (ObservationsFromIndices).
//
// `observation` may be null. In that case no real indices are contributed, and the output
// contains only noise indices plus the participation observation.
// Returns INVALID_ARGUMENT if the report's privacy level is not above NO_ADDED_PRIVACY.
// Errors from steps 1 and 2 are propagated.
lib::statusor::StatusOr<std::vector<std::unique_ptr<Observation>>>
PrivacyEncoder::MakePrivateObservations(const Observation *observation,
                                        const MetricDefinition &metric_def,
                                        const ReportDefinition &report_def) {
  const bool has_added_privacy = report_def.privacy_level() > ReportDefinition::NO_ADDED_PRIVACY;
  if (!has_added_privacy) {
    return util::Status(util::INVALID_ARGUMENT, "report has no added privacy.");
  }

  std::vector<uint64_t> true_indices;
  if (observation != nullptr) {
    CB_ASSIGN_OR_RETURN(true_indices, PrepareIndexVector(*observation, metric_def, report_def));
  }

  CB_ASSIGN_OR_RETURN(auto noisy_indices, AddNoise(true_indices, metric_def, report_def));
  return ObservationsFromIndices(noisy_indices);
}

// Returns the largest private index that `report_def` can produce for `metric_def`.
//
//   UNIQUE_DEVICE_COUNTS:          num_event_vectors - 1
//   per-value integer reports:     num_event_vectors * num_index_points - 1
//   FLEETWIDE_MEANS:               2 * num_event_vectors * num_index_points - 1
//                                  (sums and counts each use num_index_points value slots)
//
// Returns INVALID_ARGUMENT when a report type that buckets values has num_index_points == 0.
// Without this check the unsigned subtraction below would wrap around to UINT64_MAX, and
// AddNoise would use that as the noise range.
// Returns UNIMPLEMENTED for any other report type.
lib::statusor::StatusOr<uint64_t> PrivacyEncoder::MaxIndexForReport(
    const MetricDefinition &metric_def, const ReportDefinition &report_def) {
  switch (report_def.report_type()) {
    case ReportDefinition::UNIQUE_DEVICE_COUNTS: {
      return GetNumEventVectors(metric_def.metric_dimensions()) - 1;
    }
    case ReportDefinition::FLEETWIDE_OCCURRENCE_COUNTS:
    case ReportDefinition::HOURLY_VALUE_NUMERIC_STATS:
    case ReportDefinition::HOURLY_VALUE_HISTOGRAMS:
    case ReportDefinition::UNIQUE_DEVICE_NUMERIC_STATS:
    case ReportDefinition::UNIQUE_DEVICE_HISTOGRAMS: {
      if (report_def.num_index_points() == 0) {
        return util::Status(util::INVALID_ARGUMENT, "report has no index points.");
      }
      return (GetNumEventVectors(metric_def.metric_dimensions()) * report_def.num_index_points()) -
             1;
    }
    case ReportDefinition::FLEETWIDE_MEANS: {
      if (report_def.num_index_points() == 0) {
        return util::Status(util::INVALID_ARGUMENT, "report has no index points.");
      }
      // Sum buckets and count buckets each occupy num_index_points value slots.
      return 2 * (GetNumEventVectors(metric_def.metric_dimensions()) *
                  report_def.num_index_points()) -
             1;
    }

    default:
      return util::Status(util::UNIMPLEMENTED, "this is not yet implemented");
  }
}

// Converts `observation` into the list of (pre-noise) private indices for `report_def`.
// Dispatches on the report type to the matching helper and returns its result unchanged,
// including any error status. Report types without a helper yield UNIMPLEMENTED.
lib::statusor::StatusOr<std::vector<uint64_t>> PrivacyEncoder::PrepareIndexVector(
    const Observation &observation, const MetricDefinition &metric_def,
    const ReportDefinition &report_def) {
  switch (report_def.report_type()) {
    case ReportDefinition::UNIQUE_DEVICE_COUNTS:
      return PrepareIndexVectorForUniqueDeviceCount(observation, metric_def);

    case ReportDefinition::FLEETWIDE_OCCURRENCE_COUNTS:
    case ReportDefinition::HOURLY_VALUE_NUMERIC_STATS:
    case ReportDefinition::HOURLY_VALUE_HISTOGRAMS:
    case ReportDefinition::UNIQUE_DEVICE_NUMERIC_STATS:
    case ReportDefinition::UNIQUE_DEVICE_HISTOGRAMS:
      return PrepareIndexVectorForPerDeviceIntegerReport(observation, metric_def, report_def);

    case ReportDefinition::FLEETWIDE_MEANS:
      return PrepareIndexVectorForFleetwideMeansReport(observation, metric_def, report_def);

    default:
      return util::Status(util::UNIMPLEMENTED, "this is not yet implemented");
  }
}

std::vector<std::unique_ptr<Observation>> PrivacyEncoder::ObservationsFromIndices(
    const std::vector<uint64_t> &indices) {
  std::vector<std::unique_ptr<Observation>> observations;
  for (uint64_t index : indices) {
    auto observation = std::make_unique<Observation>();
    auto *private_index = observation->mutable_private_index();
    private_index->set_index(index);
    observations.push_back(std::move(observation));
  }
  auto observation = std::make_unique<Observation>();
  observation->mutable_report_participation();
  observations.push_back(std::move(observation));
  return observations;
}

// Applies RAPPOR noise to `indices` using the report's `prob_bit_flip` and the secure
// generator.
//
// Returns, in this order of checking:
//   - any error from MaxIndexForReport;
//   - INVALID_ARGUMENT if prob_bit_flip is not a number in [0, 1] (NaN included);
//   - INVALID_ARGUMENT if any input index exceeds the report's maximum index;
//   - otherwise, the result of ApplyRapporNoise.
lib::statusor::StatusOr<std::vector<uint64_t>> PrivacyEncoder::AddNoise(
    const std::vector<uint64_t> &indices, const MetricDefinition &metric_def,
    const ReportDefinition &report_def) {
  CB_ASSIGN_OR_RETURN(uint64_t max_index, MaxIndexForReport(metric_def, report_def));

  // Written as a positive range check so that NaN is rejected. Every comparison with NaN is
  // false, so the negated form `p < 0 || p > 1` would let a NaN through to ApplyRapporNoise.
  const double p = report_def.prob_bit_flip();
  if (!(p >= 0.0 && p <= 1.0)) {
    return util::Status(util::INVALID_ARGUMENT, "prob_bit_flip is not between 0 and 1");
  }

  for (const uint64_t index : indices) {
    if (index > max_index) {
      return util::Status(util::INVALID_ARGUMENT, "index is outside the range of valid indices.");
    }
  }

  return ApplyRapporNoise(indices, max_index, p, secure_gen_.get());
}

// Index preparation for UNIQUE_DEVICE_COUNTS reports.
// `observation` must hold an IntegerObservation; otherwise INVALID_ARGUMENT is returned.
// Each entry whose value is exactly 1 contributes the index of its event vector. Other
// entries are skipped. Errors from EventVectorToIndex are propagated.
lib::statusor::StatusOr<std::vector<uint64_t>>
PrivacyEncoder::PrepareIndexVectorForUniqueDeviceCount(const Observation &observation,
                                                       const MetricDefinition &metric_def) {
  if (!observation.has_integer()) {
    return util::Status(util::INVALID_ARGUMENT, "observation type is not IntegerObservation.");
  }

  std::vector<uint64_t> indices;
  for (const auto &entry : observation.integer().values()) {
    if (entry.value() != 1) {
      continue;
    }
    std::vector<uint32_t> codes(entry.event_codes().begin(), entry.event_codes().end());
    CB_ASSIGN_OR_RETURN(auto event_vector_index, EventVectorToIndex(codes, metric_def));
    indices.push_back(event_vector_index);
  }
  return indices;
}

// Index preparation for integer-valued reports (occurrence counts, numeric stats, histograms).
// `observation` must hold an IntegerObservation; otherwise INVALID_ARGUMENT is returned.
// Each entry's value is bucketed into one of report_def.num_index_points() value indices
// using the report's [min_value, max_value] range and `gen_`. That value index is then
// combined with the entry's event-vector index into a single private index.
lib::statusor::StatusOr<std::vector<uint64_t>>
PrivacyEncoder::PrepareIndexVectorForPerDeviceIntegerReport(const Observation &observation,
                                                            const MetricDefinition &metric_def,
                                                            const ReportDefinition &report_def) {
  if (!observation.has_integer()) {
    return util::Status(util::INVALID_ARGUMENT, "observation type is not IntegerObservation.");
  }

  // Loop-invariant: depends only on the metric's dimensions.
  const uint64_t max_event_vector_index = GetNumEventVectors(metric_def.metric_dimensions()) - 1;

  std::vector<uint64_t> indices;
  for (const auto &entry : observation.integer().values()) {
    std::vector<uint32_t> codes(entry.event_codes().begin(), entry.event_codes().end());
    CB_ASSIGN_OR_RETURN(auto event_vector_index, EventVectorToIndex(codes, metric_def));

    const uint64_t value_index =
        IntegerToIndex(entry.value(), report_def.min_value(), report_def.max_value(),
                       report_def.num_index_points(), gen_.get());
    indices.push_back(
        ValueAndEventVectorIndicesToIndex(value_index, event_vector_index, max_event_vector_index));
  }

  return indices;
}

// Index preparation for FLEETWIDE_MEANS reports.
// `observation` must hold a SumAndCountObservation; otherwise INVALID_ARGUMENT is returned.
//
// Each (event vector, sum, count) entry produces two private indices:
//   - the sum, bucketed by IntegerToIndex into value slots [0, num_index_points);
//   - the count, bucketed by CountToIndex and then shifted into value slots
//     [num_index_points, 2 * num_index_points).
// The shift keeps sum and count buckets distinguishable after encoding. It also matches
// MaxIndexForReport, which reserves 2 * num_event_vectors * num_index_points indices for
// this report type. Errors from EventVectorToIndex are propagated.
lib::statusor::StatusOr<std::vector<uint64_t>>
PrivacyEncoder::PrepareIndexVectorForFleetwideMeansReport(const Observation &observation,
                                                          const MetricDefinition &metric_def,
                                                          const ReportDefinition &report_def) {
  if (!observation.has_sum_and_count()) {
    return util::Status(util::INVALID_ARGUMENT, "observation type is not SumAndCountObservation.");
  }

  const uint64_t max_event_vector_index = GetNumEventVectors(metric_def.metric_dimensions()) - 1;
  const uint64_t num_index_points = report_def.num_index_points();

  std::vector<uint64_t> occurred_indices;
  for (const auto &value : observation.sum_and_count().sums_and_counts()) {
    std::vector<uint32_t> event_codes(value.event_codes().begin(), value.event_codes().end());
    CB_ASSIGN_OR_RETURN(auto event_vector_index, EventVectorToIndex(event_codes, metric_def));

    uint64_t sum_index = IntegerToIndex(value.sum(), report_def.min_value(), report_def.max_value(),
                                        report_def.num_index_points(), gen_.get());
    // Count buckets live after the sum buckets so that the two never collide.
    uint64_t count_index = CountToIndex(value.count(), report_def.max_count(),
                                        report_def.num_index_points(), gen_.get()) +
                           num_index_points;

    occurred_indices.push_back(
        ValueAndEventVectorIndicesToIndex(sum_index, event_vector_index, max_event_vector_index));
    occurred_indices.push_back(
        ValueAndEventVectorIndicesToIndex(count_index, event_vector_index, max_event_vector_index));
  }

  return occurred_indices;
}

}  // namespace cobalt::logger
