#include "src/lib/privacy/private_index_decoding.h"

#include "src/algorithms/privacy/numeric_encoding.h"
#include "src/lib/statusor/statusor.h"
#include "src/logger/event_vector_index.h"

namespace cobalt {
namespace {

// Checks that |index| is usable as a numeric-value index, i.e. that it is strictly less than
// |num_index_points|. Returns OK on success and INVALID_ARGUMENT otherwise.
util::Status ValidateIndexAsNumericValue(uint64_t index, uint64_t num_index_points) {
  const bool in_range = index < num_index_points;
  if (in_range) {
    return util::Status::OK;
  }
  return util::Status(util::INVALID_ARGUMENT,
                      "index is greater than or equal to num_index_points.");
}

// Checks that |index| lies in the count sub-range [num_index_points, 2 * num_index_points).
// Returns OK on success and INVALID_ARGUMENT (with a message naming the violated bound) otherwise.
util::Status ValidateIndexAsCount(uint64_t index, uint64_t num_index_points) {
  if (index < num_index_points) {
    return util::Status(util::INVALID_ARGUMENT, "index is less than num_index_points.");
  }
  // Compare the offset into the count range instead of computing 2 * num_index_points. That
  // product wraps around for num_index_points >= 2^63 and would then reject in-range indices.
  // The subtraction cannot underflow because index >= num_index_points at this point.
  if (index - num_index_points >= num_index_points) {
    return util::Status(util::INVALID_ARGUMENT,
                        "index is greater than or equal to 2 * num_index_points.");
  }
  return util::Status::OK;
}

// Checks that |index| is a valid histogram bucket index. Buckets are numbered 0 through
// |max_bucket_index| inclusive. Returns OK on success and INVALID_ARGUMENT otherwise.
util::Status ValidateIndexAsHistogramBucketIndex(uint32_t index, uint32_t max_bucket_index) {
  if (index <= max_bucket_index) {
    return util::Status::OK;
  }
  return util::Status(util::INVALID_ARGUMENT, "index is greater than the maximum bucket index.");
}

}  // namespace

// Decodes |index| into an event vector according to |metric_dimensions| and stores it in
// |event_vector|.
//
// On failure, the error status from logger::EventVectorFromIndex is returned and
// |event_vector| is not modified.
util::Status DecodePrivateIndexAsEventVector(
    uint64_t index,
    const google::protobuf::RepeatedPtrField<MetricDefinition::MetricDimension>& metric_dimensions,
    std::vector<uint32_t>* event_vector) {
  lib::statusor::StatusOr<std::vector<uint32_t>> result =
      logger::EventVectorFromIndex(index, metric_dimensions);
  if (!result.ok()) {
    return result.status();
  }
  *event_vector = result.ValueOrDie();
  return util::Status::OK;
}

// Decodes |index| as an (event vector, integer value) pair.
//
// |index| is split into an event-vector index and a value index. The event-vector index is decoded
// into |event_vector|. The value index must be less than |num_index_points|; it is then converted
// back to an integer with IntegerFromIndex, using |min_value|, |max_value| and |num_index_points|.
// The result is written to |integer_value|.
//
// Returns the first error encountered. Note that |event_vector| may already have been written when
// the value-index check fails.
util::Status DecodePrivateIndexAsInteger(
    uint64_t index,
    const google::protobuf::RepeatedPtrField<MetricDefinition::MetricDimension>& metric_dimensions,
    int64_t min_value, int64_t max_value, uint64_t num_index_points,
    std::vector<uint32_t>* event_vector, int64_t* integer_value) {
  uint64_t event_vector_index = 0;
  uint64_t value_index = 0;
  ValueAndEventVectorIndicesFromIndex(index, logger::GetNumEventVectors(metric_dimensions) - 1,
                                      &value_index, &event_vector_index);
  if (util::Status decode_event_vector_index =
          DecodePrivateIndexAsEventVector(event_vector_index, metric_dimensions, event_vector);
      !decode_event_vector_index.ok()) {
    return decode_event_vector_index;
  }
  if (util::Status validate_value_index =
          ValidateIndexAsNumericValue(value_index, num_index_points);
      !validate_value_index.ok()) {
    return validate_value_index;
  }
  // Use the dedicated integer decoder, as the SUM branch of DecodePrivateIndexAsSumOrCount does.
  // The previous approach, static_cast<int64_t> on the result of DoubleFromIndex, truncates toward
  // zero. It therefore drops a whole unit whenever floating-point error leaves the double just
  // below an integer (e.g. 4.9999999 -> 4).
  *integer_value = IntegerFromIndex(value_index, min_value, max_value, num_index_points);
  return util::Status::OK;
}

// Decodes |index| as an (event vector, floating-point value) pair.
//
// |index| is split into an event-vector index and a value index. The event-vector index is decoded
// into |event_vector|. The value index must be less than |num_index_points|; it is then converted
// with DoubleFromIndex, using |min_value|, |max_value| and |num_index_points|. The result is written
// to |double_value|.
//
// Returns the first error encountered. |event_vector| may already have been written when the
// value-index check fails.
util::Status DecodePrivateIndexAsDouble(
    uint64_t index,
    const google::protobuf::RepeatedPtrField<MetricDefinition::MetricDimension>& metric_dimensions,
    int64_t min_value, int64_t max_value, uint64_t num_index_points,
    std::vector<uint32_t>* event_vector, double* double_value) {
  const auto max_event_vector_index = logger::GetNumEventVectors(metric_dimensions) - 1;
  uint64_t ev_index = 0;
  uint64_t numeric_index = 0;
  ValueAndEventVectorIndicesFromIndex(index, max_event_vector_index, &numeric_index, &ev_index);

  util::Status ev_status =
      DecodePrivateIndexAsEventVector(ev_index, metric_dimensions, event_vector);
  if (!ev_status.ok()) {
    return ev_status;
  }

  util::Status range_status = ValidateIndexAsNumericValue(numeric_index, num_index_points);
  if (!range_status.ok()) {
    return range_status;
  }

  *double_value = DoubleFromIndex(numeric_index, min_value, max_value, num_index_points);
  return util::Status::OK;
}

// Decodes |index| as an event vector plus either a SUM or a COUNT.
//
// |index| is split into an event-vector index and a value index. The event-vector index is decoded
// into |event_vector|. If IsCountIndex reports the value index as a count index, it must pass
// ValidateIndexAsCount and is decoded with CountFromIndex (bounded by |max_count|). Otherwise it
// must pass ValidateIndexAsNumericValue and is decoded with IntegerFromIndex over
// [|min_value|, |max_value|]. |sum_or_count| receives the kind and the decoded value.
//
// Returns the first error encountered. |event_vector| may already have been written when a
// value-index check fails.
cobalt::util::Status DecodePrivateIndexAsSumOrCount(
    uint64_t index,
    const google::protobuf::RepeatedPtrField<MetricDefinition::MetricDimension>& metric_dimensions,
    int64_t min_value, int64_t max_value, uint64_t max_count, uint64_t num_index_points,
    std::vector<uint32_t>* event_vector, SumOrCount* sum_or_count) {
  const auto max_event_vector_index = logger::GetNumEventVectors(metric_dimensions) - 1;
  uint64_t ev_index = 0;
  uint64_t numeric_index = 0;
  ValueAndEventVectorIndicesFromIndex(index, max_event_vector_index, &numeric_index, &ev_index);

  util::Status ev_status =
      DecodePrivateIndexAsEventVector(ev_index, metric_dimensions, event_vector);
  if (!ev_status.ok()) {
    return ev_status;
  }

  if (!IsCountIndex(numeric_index, num_index_points)) {
    // SUM: the value index encodes an integer in the numeric range.
    util::Status sum_status = ValidateIndexAsNumericValue(numeric_index, num_index_points);
    if (!sum_status.ok()) {
      return sum_status;
    }
    sum_or_count->type = SumOrCount::SUM;
    sum_or_count->sum = IntegerFromIndex(numeric_index, min_value, max_value, num_index_points);
    return util::Status::OK;
  }

  // COUNT: the value index lies in the count sub-range.
  util::Status count_status = ValidateIndexAsCount(numeric_index, num_index_points);
  if (!count_status.ok()) {
    return count_status;
  }
  sum_or_count->type = SumOrCount::COUNT;
  sum_or_count->count = CountFromIndex(numeric_index, max_count, num_index_points);
  return util::Status::OK;
}

// Decodes |index| as an (event vector, histogram bucket index) pair.
//
// |index| is split into an event-vector index and a value index. The event-vector index is decoded
// into |event_vector|. The value index is interpreted directly as a bucket index, which must not
// exceed |max_bucket_index|; on success it is written to |bucket_index|.
//
// Returns the first error encountered. When the bucket index is out of range, |bucket_index| is
// left untouched, but |event_vector| may already have been written.
util::Status DecodePrivateIndexAsHistogramBucketIndex(
    uint64_t index,
    const google::protobuf::RepeatedPtrField<MetricDefinition::MetricDimension>& metric_dimensions,
    uint32_t max_bucket_index, std::vector<uint32_t>* event_vector, uint32_t* bucket_index) {
  uint64_t event_vector_index = 0;
  uint64_t value_index = 0;
  ValueAndEventVectorIndicesFromIndex(index, logger::GetNumEventVectors(metric_dimensions) - 1,
                                      &value_index, &event_vector_index);
  if (util::Status decode_event_vector_index =
          DecodePrivateIndexAsEventVector(event_vector_index, metric_dimensions, event_vector);
      !decode_event_vector_index.ok()) {
    return decode_event_vector_index;
  }

  // Range-check the full 64-bit value before narrowing. Casting to uint32_t first would let an
  // out-of-range value such as 2^32 + 1 wrap to 1 and pass validation as a legitimate bucket.
  if (value_index > max_bucket_index) {
    return util::Status(util::INVALID_ARGUMENT, "index is greater than the maximum bucket index.");
  }
  // Safe: value_index <= max_bucket_index, which is itself a uint32_t.
  *bucket_index = static_cast<uint32_t>(value_index);
  return util::Status::OK;
}

// Decodes |index| as an (event vector, histogram bucket index, bucket count) triple.
//
// |index| is split into an event-vector index and a value index. The event-vector index is decoded
// into |event_vector|. The value index is passed to HistogramBucketAndCountFromIndex (with
// |max_count| and |num_index_points|), which writes |bucket_index| and |bucket_count|. The resulting
// bucket index must then not exceed |max_bucket_index|.
//
// Returns the first error encountered. Note that |bucket_index| and |bucket_count| are written
// before the bucket index is validated, so they may hold rejected values on error.
util::Status DecodePrivateIndexAsHistogramBucketIndexAndCount(
    uint64_t index,
    const google::protobuf::RepeatedPtrField<MetricDefinition::MetricDimension>& metric_dimensions,
    uint32_t max_bucket_index, uint64_t max_count, uint64_t num_index_points,
    std::vector<uint32_t>* event_vector, uint32_t* bucket_index, uint64_t* bucket_count) {
  const auto max_event_vector_index = logger::GetNumEventVectors(metric_dimensions) - 1;
  uint64_t ev_index = 0;
  uint64_t packed_value = 0;
  ValueAndEventVectorIndicesFromIndex(index, max_event_vector_index, &packed_value, &ev_index);

  util::Status ev_status =
      DecodePrivateIndexAsEventVector(ev_index, metric_dimensions, event_vector);
  if (!ev_status.ok()) {
    return ev_status;
  }

  HistogramBucketAndCountFromIndex(packed_value, max_count, num_index_points, bucket_index,
                                   bucket_count);

  // The validator yields util::Status::OK on success, so its result is the final status either way.
  return ValidateIndexAsHistogramBucketIndex(*bucket_index, max_bucket_index);
}

}  // namespace cobalt
