#include "src/logger/privacy_encoder.h"

#include "src/algorithms/privacy/numeric_encoding.h"
#include "src/algorithms/privacy/rappor.h"
#include "src/algorithms/random/random.h"
#include "src/lib/statusor/status_macros.h"
#include "src/logger/event_vector_index.h"
#include "src/pb/observation.pb.h"
#include "src/registry/buckets_config.h"

namespace cobalt::logger {
namespace {

// Computes the total number of histogram buckets described by |int_buckets|: the configured
// num_buckets of the exponential or linear bucket definition, plus two extra buckets for underflow
// and overflow. Returns INVALID_ARGUMENT for any other buckets case.
lib::statusor::StatusOr<uint32_t> GetNumIntegerBuckets(const IntegerBuckets &int_buckets) {
  // One underflow bucket plus one overflow bucket.
  constexpr uint32_t kNumEdgeBuckets = 2;

  const auto buckets_kind = int_buckets.buckets_case();
  if (buckets_kind == IntegerBuckets::kExponential) {
    uint32_t total = kNumEdgeBuckets;
    total += int_buckets.exponential().num_buckets();
    return total;
  }
  if (buckets_kind == IntegerBuckets::kLinear) {
    uint32_t total = kNumEdgeBuckets;
    total += int_buckets.linear().num_buckets();
    return total;
  }
  return util::Status(util::INVALID_ARGUMENT, "invalid IntegerBuckets type.");
}

}  // namespace

PrivacyEncoder::PrivacyEncoder(std::unique_ptr<SecureBitGeneratorInterface<uint32_t>> secure_gen,
                               std::unique_ptr<BitGeneratorInterface<uint32_t>> gen)
    : secure_gen_(std::move(secure_gen)), gen_(std::move(gen)) {}

// Builds a PrivacyEncoder whose secure generator is a SecureRandomNumberGenerator and whose
// other generator is a RandomNumberGenerator.
std::unique_ptr<PrivacyEncoder> PrivacyEncoder::MakeSecurePrivacyEncoder() {
  auto secure_generator = std::make_unique<SecureRandomNumberGenerator>();
  auto generator = std::make_unique<RandomNumberGenerator>();
  return std::make_unique<PrivacyEncoder>(std::move(secure_generator), std::move(generator));
}

// Produces the observations for |report_def|.
//
// If the report's privacy level is greater than NO_ADDED_PRIVACY, the work is delegated to
// MakePrivateObservations(). Otherwise |observation| is returned unchanged as the only element of
// the result vector.
lib::statusor::StatusOr<std::vector<std::unique_ptr<Observation>>>
PrivacyEncoder::MaybeMakePrivateObservations(std::unique_ptr<Observation> observation,
                                             const MetricDefinition &metric_def,
                                             const ReportDefinition &report_def) {
  const bool adds_privacy = report_def.privacy_level() > ReportDefinition::NO_ADDED_PRIVACY;
  if (!adds_privacy) {
    // No privacy encoding requested: pass the observation through as-is.
    std::vector<std::unique_ptr<Observation>> passthrough;
    passthrough.push_back(std::move(observation));
    return passthrough;
  }

  return MakePrivateObservations(observation.get(), metric_def, report_def);
}

// Converts |observation| into a list of private observations for |report_def|.
//
// Returns INVALID_ARGUMENT if the report's privacy level is not greater than NO_ADDED_PRIVACY.
// Otherwise the observation is turned into a vector of indices (empty when |observation| is
// null), noise is added through AddNoise(), and the noised indices are wrapped into Observations
// by ObservationsFromIndices(). Errors from the intermediate steps are propagated.
lib::statusor::StatusOr<std::vector<std::unique_ptr<Observation>>>
PrivacyEncoder::MakePrivateObservations(const Observation *observation,
                                        const MetricDefinition &metric_def,
                                        const ReportDefinition &report_def) {
  // This path is only valid for reports that actually request added privacy.
  const bool has_added_privacy =
      report_def.privacy_level() > ReportDefinition::NO_ADDED_PRIVACY;
  if (!has_added_privacy) {
    return util::Status(util::INVALID_ARGUMENT, "report has no added privacy.");
  }

  // A null |observation| contributes no indices; noise is still applied below.
  std::vector<uint64_t> occurred_indices;
  if (observation != nullptr) {
    CB_ASSIGN_OR_RETURN(occurred_indices,
                        PrepareIndexVector(*observation, metric_def, report_def));
  }

  CB_ASSIGN_OR_RETURN(std::vector<uint64_t> noisy_indices,
                      AddNoise(occurred_indices, metric_def, report_def));
  return ObservationsFromIndices(noisy_indices);
}

// Returns the largest valid index for the combination of |metric_def| and |report_def|.
//
// The value is (number of indices) - 1, where the number of indices depends on the report type:
//   - UNIQUE_DEVICE_COUNTS: the number of event vectors.
//   - FLEETWIDE_OCCURRENCE_COUNTS, HOURLY_VALUE_NUMERIC_STATS, UNIQUE_DEVICE_NUMERIC_STATS:
//     event vectors * num_index_points.
//   - FLEETWIDE_MEANS: 2 * event vectors * num_index_points.
//   - HOURLY_VALUE_HISTOGRAMS, UNIQUE_DEVICE_HISTOGRAMS: event vectors * number of buckets of the
//     report's int_buckets.
//   - FLEETWIDE_HISTOGRAMS: event vectors * number of buckets * num_index_points, where the
//     buckets come from the report (INTEGER metrics) or the metric (INTEGER_HISTOGRAM metrics).
//
// Returns INVALID_ARGUMENT if the bucket configuration is invalid or the metric type is not
// supported by a FLEETWIDE_HISTOGRAMS report, and UNIMPLEMENTED for any other report type.
lib::statusor::StatusOr<uint64_t> PrivacyEncoder::MaxIndexForReport(
    const MetricDefinition &metric_def, const ReportDefinition &report_def) {
  switch (report_def.report_type()) {
    case ReportDefinition::UNIQUE_DEVICE_COUNTS: {
      return GetNumEventVectors(metric_def.metric_dimensions()) - 1;
    }
    case ReportDefinition::FLEETWIDE_OCCURRENCE_COUNTS:
    case ReportDefinition::HOURLY_VALUE_NUMERIC_STATS:
    case ReportDefinition::UNIQUE_DEVICE_NUMERIC_STATS: {
      return (GetNumEventVectors(metric_def.metric_dimensions()) * report_def.num_index_points()) -
             1;
    }
    case ReportDefinition::FLEETWIDE_MEANS: {
      return 2 * (GetNumEventVectors(metric_def.metric_dimensions()) *
                  report_def.num_index_points()) -
             1;
    }
    case ReportDefinition::HOURLY_VALUE_HISTOGRAMS:
    case ReportDefinition::UNIQUE_DEVICE_HISTOGRAMS: {
      CB_ASSIGN_OR_RETURN(uint32_t num_buckets, GetNumIntegerBuckets(report_def.int_buckets()));
      return (GetNumEventVectors(metric_def.metric_dimensions()) * num_buckets) - 1;
    }
    case ReportDefinition::FLEETWIDE_HISTOGRAMS: {
      // The bucket definition lives on the report for INTEGER metrics and on the metric for
      // INTEGER_HISTOGRAM metrics. A failed lookup is propagated as an error instead of being
      // unwrapped unchecked.
      uint32_t num_buckets = 0;
      switch (metric_def.metric_type()) {
        case MetricDefinition::INTEGER: {
          CB_ASSIGN_OR_RETURN(num_buckets, GetNumIntegerBuckets(report_def.int_buckets()));
          break;
        }
        case MetricDefinition::INTEGER_HISTOGRAM: {
          CB_ASSIGN_OR_RETURN(num_buckets, GetNumIntegerBuckets(metric_def.int_buckets()));
          break;
        }
        default:
          return util::Status(util::INVALID_ARGUMENT,
                              "invalid metric type with FLEETWIDE_HISTOGRAMS report.");
      }
      return (GetNumEventVectors(metric_def.metric_dimensions()) * num_buckets *
              report_def.num_index_points()) -
             1;
    }

    default:
      return util::Status(util::UNIMPLEMENTED, "this is not yet implemented");
  }
}

// Dispatches on the report type of |report_def| to the matching PrepareIndexVectorFor*() helper
// and returns that helper's result (indices or error). Report types without a helper yield
// UNIMPLEMENTED.
lib::statusor::StatusOr<std::vector<uint64_t>> PrivacyEncoder::PrepareIndexVector(
    const Observation &observation, const MetricDefinition &metric_def,
    const ReportDefinition &report_def) {
  switch (report_def.report_type()) {
    case ReportDefinition::UNIQUE_DEVICE_COUNTS:
      return PrepareIndexVectorForUniqueDeviceCount(observation, metric_def);

    case ReportDefinition::FLEETWIDE_OCCURRENCE_COUNTS:
    case ReportDefinition::HOURLY_VALUE_NUMERIC_STATS:
    case ReportDefinition::UNIQUE_DEVICE_NUMERIC_STATS:
      return PrepareIndexVectorForPerDeviceIntegerReport(observation, metric_def, report_def);

    case ReportDefinition::FLEETWIDE_MEANS:
      return PrepareIndexVectorForFleetwideMeansReport(observation, metric_def, report_def);

    case ReportDefinition::HOURLY_VALUE_HISTOGRAMS:
    case ReportDefinition::UNIQUE_DEVICE_HISTOGRAMS:
      return PrepareIndexVectorForPerDeviceHistogramsReport(observation, metric_def, report_def);

    case ReportDefinition::FLEETWIDE_HISTOGRAMS:
      return PrepareIndexVectorForFleetwideHistogramsReport(observation, metric_def, report_def);

    default:
      return util::Status(util::UNIMPLEMENTED, "this is not yet implemented");
  }
}

std::vector<std::unique_ptr<Observation>> PrivacyEncoder::ObservationsFromIndices(
    const std::vector<uint64_t> &indices) {
  std::vector<std::unique_ptr<Observation>> observations;
  for (uint64_t index : indices) {
    auto observation = std::make_unique<Observation>();
    auto *private_index = observation->mutable_private_index();
    private_index->set_index(index);
    observations.push_back(std::move(observation));
  }
  auto observation = std::make_unique<Observation>();
  observation->mutable_report_participation();
  observations.push_back(std::move(observation));
  return observations;
}

// Validates |indices| against the report configuration and applies noise to them.
//
// Returns the result of ApplyRapporNoise() called with |indices|, the maximum index for the
// report, the report's prob_bit_flip and the secure bit generator.
//
// Errors:
//   - any error from MaxIndexForReport() is propagated;
//   - INVALID_ARGUMENT if prob_bit_flip is not in [0, 1] (this includes NaN);
//   - INVALID_ARGUMENT if any element of |indices| exceeds the maximum index.
lib::statusor::StatusOr<std::vector<uint64_t>> PrivacyEncoder::AddNoise(
    const std::vector<uint64_t> &indices, const MetricDefinition &metric_def,
    const ReportDefinition &report_def) {
  CB_ASSIGN_OR_RETURN(uint64_t max_index, MaxIndexForReport(metric_def, report_def));

  const double p = report_def.prob_bit_flip();
  // Written as a negated in-range test so that NaN, for which every ordered comparison is false,
  // is rejected rather than silently accepted.
  if (!(p >= 0 && p <= 1)) {
    return util::Status(util::INVALID_ARGUMENT, "prob_bit_flip is not between 0 and 1");
  }

  for (const uint64_t index : indices) {
    if (index > max_index) {
      return util::Status(util::INVALID_ARGUMENT, "index is outside the range of valid indices.");
    }
  }

  return ApplyRapporNoise(indices, max_index, p, secure_gen_.get());
}

// Builds the index vector for a UNIQUE_DEVICE_COUNTS report.
//
// |observation| must hold an IntegerObservation, otherwise INVALID_ARGUMENT is returned. For every
// entry whose value equals 1, the entry's event codes are converted to an index with
// EventVectorToIndex() and appended to the result; entries with any other value are skipped.
// Errors from EventVectorToIndex() are propagated.
lib::statusor::StatusOr<std::vector<uint64_t>>
PrivacyEncoder::PrepareIndexVectorForUniqueDeviceCount(const Observation &observation,
                                                       const MetricDefinition &metric_def) {
  if (!observation.has_integer()) {
    return util::Status(util::INVALID_ARGUMENT, "observation type is not IntegerObservation.");
  }

  std::vector<uint64_t> indices;
  for (const auto &entry : observation.integer().values()) {
    if (entry.value() != 1) {
      continue;
    }
    std::vector<uint32_t> codes(entry.event_codes().begin(), entry.event_codes().end());
    CB_ASSIGN_OR_RETURN(auto event_vector_index, EventVectorToIndex(codes, metric_def));
    indices.push_back(event_vector_index);
  }
  return indices;
}

// Builds the index vector for the per-device integer report types.
//
// |observation| must hold an IntegerObservation, otherwise INVALID_ARGUMENT is returned. Each entry
// contributes one index: its event codes are mapped with EventVectorToIndex(), its value is mapped
// with IntegerToIndex() using the report's min_value, max_value and num_index_points, and the two
// are combined by ValueAndEventVectorIndicesToIndex(). Errors from EventVectorToIndex() are
// propagated.
lib::statusor::StatusOr<std::vector<uint64_t>>
PrivacyEncoder::PrepareIndexVectorForPerDeviceIntegerReport(const Observation &observation,
                                                            const MetricDefinition &metric_def,
                                                            const ReportDefinition &report_def) {
  if (!observation.has_integer()) {
    return util::Status(util::INVALID_ARGUMENT, "observation type is not IntegerObservation.");
  }

  std::vector<uint64_t> indices;
  for (const auto &entry : observation.integer().values()) {
    std::vector<uint32_t> codes(entry.event_codes().begin(), entry.event_codes().end());
    CB_ASSIGN_OR_RETURN(auto event_vector_index, EventVectorToIndex(codes, metric_def));

    uint64_t value_index =
        IntegerToIndex(entry.value(), report_def.min_value(), report_def.max_value(),
                       report_def.num_index_points(), gen_.get());

    const auto max_event_vector_index = GetNumEventVectors(metric_def.metric_dimensions()) - 1;
    indices.push_back(ValueAndEventVectorIndicesToIndex(value_index, event_vector_index,
                                                        max_event_vector_index));
  }

  return indices;
}

// Builds the index vector for a FLEETWIDE_MEANS report.
//
// |observation| must hold a SumAndCountObservation, otherwise INVALID_ARGUMENT is returned. Each
// entry contributes two indices, in this order: one for its sum (via IntegerToIndex()) and one for
// its count (via CountToIndex()), each combined with the entry's event vector index by
// ValueAndEventVectorIndicesToIndex(). Errors from EventVectorToIndex() are propagated.
lib::statusor::StatusOr<std::vector<uint64_t>>
PrivacyEncoder::PrepareIndexVectorForFleetwideMeansReport(const Observation &observation,
                                                          const MetricDefinition &metric_def,
                                                          const ReportDefinition &report_def) {
  if (!observation.has_sum_and_count()) {
    return util::Status(util::INVALID_ARGUMENT, "observation type is not SumAndCountObservation.");
  }

  std::vector<uint64_t> indices;
  for (const auto &entry : observation.sum_and_count().sums_and_counts()) {
    std::vector<uint32_t> codes(entry.event_codes().begin(), entry.event_codes().end());
    CB_ASSIGN_OR_RETURN(auto event_vector_index, EventVectorToIndex(codes, metric_def));

    // The sum is encoded before the count; both draw from gen_, so the order is kept fixed.
    uint64_t sum_index = IntegerToIndex(entry.sum(), report_def.min_value(), report_def.max_value(),
                                        report_def.num_index_points(), gen_.get());
    uint64_t count_index = CountToIndex(entry.count(), report_def.max_count(),
                                        report_def.num_index_points(), gen_.get());

    const auto max_event_vector_index = GetNumEventVectors(metric_def.metric_dimensions()) - 1;
    indices.push_back(
        ValueAndEventVectorIndicesToIndex(sum_index, event_vector_index, max_event_vector_index));
    indices.push_back(
        ValueAndEventVectorIndicesToIndex(count_index, event_vector_index, max_event_vector_index));
  }

  return indices;
}

// Builds the index vector for the per-device histogram report types.
//
// |observation| must hold an IntegerObservation, otherwise INVALID_ARGUMENT is returned. Each
// entry's value is mapped to a bucket index using the report's int_buckets, and that bucket index
// is combined with the entry's event vector index by ValueAndEventVectorIndicesToIndex().
//
// Returns INVALID_ARGUMENT if no bucket configuration could be created from the report's
// int_buckets. Errors from EventVectorToIndex() are propagated.
lib::statusor::StatusOr<std::vector<uint64_t>>
PrivacyEncoder::PrepareIndexVectorForPerDeviceHistogramsReport(const Observation &observation,
                                                               const MetricDefinition &metric_def,
                                                               const ReportDefinition &report_def) {
  if (!observation.has_integer()) {
    return util::Status(util::INVALID_ARGUMENT, "observation type is not IntegerObservation.");
  }

  std::unique_ptr<config::IntegerBucketConfig> integer_buckets =
      config::IntegerBucketConfig::CreateFromProto(report_def.int_buckets());
  // Guard against a null bucket config before it is dereferenced in the loop below.
  if (integer_buckets == nullptr) {
    return util::Status(util::INVALID_ARGUMENT, "invalid IntegerBuckets in report definition.");
  }

  // Loop-invariant: depends only on the metric's dimensions.
  const auto max_event_vector_index = GetNumEventVectors(metric_def.metric_dimensions()) - 1;

  std::vector<uint64_t> occurred_indices;
  for (const auto &value : observation.integer().values()) {
    std::vector<uint32_t> event_codes(value.event_codes().begin(), value.event_codes().end());
    CB_ASSIGN_OR_RETURN(auto event_vector_index, EventVectorToIndex(event_codes, metric_def));

    uint32_t bucket_index = integer_buckets->BucketIndex(value.value());
    occurred_indices.push_back(ValueAndEventVectorIndicesToIndex(bucket_index, event_vector_index,
                                                                 max_event_vector_index));
  }

  return occurred_indices;
}

// Builds the index vector for a FLEETWIDE_HISTOGRAMS report.
//
// |observation| must hold an IndexHistogramObservation, otherwise INVALID_ARGUMENT is returned.
// Every (bucket, count) pair of every histogram contributes one index: the pair is encoded with
// HistogramBucketAndCountToIndex() using the report's max_count and num_index_points, and the
// result is combined with the histogram's event vector index by
// ValueAndEventVectorIndicesToIndex().
//
// A histogram is read in one of two layouts:
//   - bucket_indices empty: bucket_counts(i) is the count for bucket i;
//   - bucket_indices non-empty: bucket_counts(i) is the count for bucket bucket_indices(i).
//
// Returns INVALID_ARGUMENT if a histogram has bucket_indices but fewer bucket_counts than
// bucket_indices. Errors from EventVectorToIndex() are propagated.
lib::statusor::StatusOr<std::vector<uint64_t>>
PrivacyEncoder::PrepareIndexVectorForFleetwideHistogramsReport(const Observation &observation,
                                                               const MetricDefinition &metric_def,
                                                               const ReportDefinition &report_def) {
  if (!observation.has_index_histogram()) {
    return util::Status(util::INVALID_ARGUMENT,
                        "observation type is not IndexHistogramObservation.");
  }

  // Loop-invariant: depends only on the metric's dimensions.
  const auto max_event_vector_index = GetNumEventVectors(metric_def.metric_dimensions()) - 1;

  std::vector<uint64_t> occurred_indices;
  for (const auto &histogram : observation.index_histogram().index_histograms()) {
    std::vector<uint32_t> event_codes(histogram.event_codes().begin(),
                                      histogram.event_codes().end());
    CB_ASSIGN_OR_RETURN(auto event_vector_index, EventVectorToIndex(event_codes, metric_def));

    // If histogram.bucket_indices() is empty, histogram.bucket_counts(i) is the count for the
    // i-th index in the histogram.
    if (histogram.bucket_indices_size() == 0) {
      for (int64_t bucket_index = 0; bucket_index < histogram.bucket_counts_size();
           ++bucket_index) {
        uint64_t histogram_index = HistogramBucketAndCountToIndex(
            histogram.bucket_counts(bucket_index), bucket_index, report_def.max_count(),
            report_def.num_index_points(), gen_.get());
        occurred_indices.push_back(ValueAndEventVectorIndicesToIndex(
            histogram_index, event_vector_index, max_event_vector_index));
      }
    } else {
      // If histogram.bucket_indices() is not empty, histogram.bucket_counts(i) is the count for
      // the histogram.bucket_indices(i)-th index in the histogram. Every bucket index therefore
      // needs a matching count; reject the histogram rather than read past the end of
      // bucket_counts.
      if (histogram.bucket_counts_size() < histogram.bucket_indices_size()) {
        return util::Status(util::INVALID_ARGUMENT,
                            "histogram has fewer bucket counts than bucket indices.");
      }
      for (int i = 0; i < histogram.bucket_indices_size(); ++i) {
        uint64_t histogram_index = HistogramBucketAndCountToIndex(
            histogram.bucket_counts(i), histogram.bucket_indices(i), report_def.max_count(),
            report_def.num_index_points(), gen_.get());
        occurred_indices.push_back(ValueAndEventVectorIndicesToIndex(
            histogram_index, event_vector_index, max_event_vector_index));
      }
    }
  }
  return occurred_indices;
}

}  // namespace cobalt::logger
