//
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

// Tool for running a query with ZetaSQL. Supports reading from a csv file.

#include <math.h>

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "base/logging.h"
#include "google/protobuf/descriptor.h"
#include "zetasql/public/analyzer_options.h"
#include "zetasql/public/catalog.h"
#include "zetasql/public/language_options.h"
#include "zetasql/public/options.pb.h"
#include "zetasql/public/simple_catalog.h"
#include "zetasql/public/type.pb.h"
#include "zetasql/public/value.h"
#include "zetasql/resolved_ast/resolved_ast.h"
#include "zetasql/resolved_ast/resolved_ast_visitor.h"
#include "zetasql/resolved_ast/resolved_node.h"
#include "zetasql/resolved_ast/resolved_node_kind.pb.h"
#include "zetasql/tools/execute_query/execute_query_tool.h"
#include "absl/flags/flag.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "base/status_macros.h"
#include "absl/flags/parse.h"

// --data_set: path of the CSV file to query. The file name (minus a ".csv"
// extension) is also used as the name of the table the data is loaded into.
ABSL_FLAG(std::string, data_set, "",
          "A CSV file containing the data to be queried, whose string-typed "
          "column names are determined from the first (header) row. The data "
          "is loaded into a table with the same name as the file (without the "
          ".csv extension, if it exists).");

// --userid_col: name of the column of the data set that holds the user IDs;
// it is passed to the table's anonymization info.
ABSL_FLAG(
    std::string, userid_col, "",
    "A string matching the name of the column in the data set containing the "
    "user IDs, to be used in anonymization queries.");

// Resolved-AST visitor that validates the options supplied in the
// WITH ANONYMIZATION OPTIONS() clause of an anonymized aggregate scan.
//
// Every option value must be an INT64 or DOUBLE literal and strictly
// positive. In addition:
//   * EPSILON, DELTA and KAPPA must all be present;
//   * DELTA must not exceed 1;
//   * KAPPA must be a whole number;
//   * K_THRESHOLD is rejected in favour of DELTA.
class VerifyAnonymizationParametersVisitor
    : public zetasql::ResolvedASTVisitor {
 public:
  // Checks the anonymization options attached to `node`.
  //
  // Returns OkStatus() when all options are acceptable. Returns
  // InvalidArgumentError when an option is not a literal, is out of range, is
  // disallowed, or when a required option is missing, and InternalError when
  // a literal has a type other than INT64 or DOUBLE.
  //
  // NOTE: this override inspects only `node`'s own option list; it does not
  // forward the visit to the node's children.
  absl::Status VisitResolvedAnonymizedAggregateScan(
      const zetasql::ResolvedAnonymizedAggregateScan* node) override {
    bool epsilon_provided = false;
    bool delta_provided = false;
    bool kappa_provided = false;

    for (const auto& anon_option : node->anonymization_option_list()) {
      // Upper-case the name once so that comparisons and error messages do
      // not depend on how the option was spelled in the query.
      const std::string name = absl::AsciiStrToUpper(anon_option->name());
      const zetasql::ResolvedExpr* anon_option_expr = anon_option->value();

      // Only literal values can be range-checked here.
      if (anon_option_expr->node_kind() !=
          zetasql::ResolvedNodeKind::RESOLVED_LITERAL) {
        return absl::InvalidArgumentError(absl::StrCat(
            "The value of anonymization option ", name, " cannot be ",
            "interpreted, since it is not a literal, but is a ",
            anon_option_expr->node_kind_string(), "."));
      }

      // Extract the anonymization option value as a double.
      const zetasql::Value& anon_option_value =
          anon_option_expr->GetAs<zetasql::ResolvedLiteral>()->value();
      double anon_option_double = 0.0;
      switch (anon_option_value.type_kind()) {
        case zetasql::TypeKind::TYPE_INT64:
          anon_option_double =
              static_cast<double>(anon_option_value.int64_value());
          break;
        case zetasql::TypeKind::TYPE_DOUBLE:
          anon_option_double = anon_option_value.double_value();
          break;
        default:  // Unexpected anon_option_value.type_kind()
          return absl::InternalError(absl::StrCat(
              "Anonymization option ", name,
              " is expected to be parsed as either an INT64 or",
              " DOUBLE, but is a ",
              anon_option_value.type()->ShortTypeName(
                  zetasql::PRODUCT_EXTERNAL),
              "."));
      }

      // Return an error if any anonymization parameter is not positive.
      if (anon_option_double <= 0) {
        return absl::InvalidArgumentError(absl::StrCat(
            "Anonymization option ", name, " must be positive, but is ",
            anon_option_double, "."));
      }

      if (name == "EPSILON") {
        epsilon_provided = true;
      } else if (name == "DELTA") {
        delta_provided = true;
        // Return an error if delta is provided and is larger than 1.
        if (anon_option_double > 1) {
          return absl::InvalidArgumentError(absl::StrCat(
              "Anonymization option ", name,
              " must be greater than 0 and less than or equal to 1, but",
              " is ", anon_option_double, "."));
        }
      } else if (name == "KAPPA") {
        kappa_provided = true;
        // Return an error if kappa is specified but is not an integer, in
        // case the SQL interpreter did not catch it first.
        double intpart;
        const double fraction = modf(anon_option_double, &intpart);
        if (fraction != 0.0) {
          return absl::InvalidArgumentError(absl::StrCat(
              "Anonymization option ", name, " must be an integer, but is ",
              anon_option_double, "."));
        }
      } else if (name == "K_THRESHOLD") {
        // K_THRESHOLD is not accepted; DELTA should be used instead.
        return absl::InvalidArgumentError(
            "Please use DELTA instead of K_THRESHOLD. DELTA can be"
            " calculated using Theorem 2 of Wilson et al.'s paper on"
            " Differentially Private SQL with Bounded User Contribution"
            " (available at https://arxiv.org/pdf/1909.01917.pdf).");
      }
    }

    // Return an error if epsilon, delta, or kappa are not provided.
    if (!(epsilon_provided && delta_provided && kappa_provided)) {
      return absl::InvalidArgumentError(
          "ZetaSQL differential privacy queries must specify EPSILON, "
          "DELTA, and KAPPA in the WITH ANONYMIZATION OPTIONS() clause.");
    }
    return absl::OkStatus();
  }
};

// Returns the last '/'-separated component of `file_path` with one trailing
// ".csv" or ".CSV" extension removed, if present.
//
// Examples: "/tmp/users.csv" -> "users", "users.CSV" -> "users",
// "dir/users" -> "users", "dir/" -> "", "" -> "".
// At most one extension is removed, so "x.CSV.csv" yields "x.CSV". Other
// capitalizations (e.g. ".Csv") are left untouched.
static std::string GetCSVFileNameFromPath(const std::string_view file_path) {
  // Everything after the last '/' is the file name; without a '/', the whole
  // path is the file name.
  const std::string_view::size_type last_slash = file_path.rfind('/');
  std::string_view file_name = (last_slash == std::string_view::npos)
                                   ? file_path
                                   : file_path.substr(last_slash + 1);

  for (const std::string_view extension :
       {std::string_view(".csv"), std::string_view(".CSV")}) {
    if (file_name.size() >= extension.size() &&
        file_name.substr(file_name.size() - extension.size()) == extension) {
      file_name.remove_suffix(extension.size());
      break;  // Strip a single extension only.
    }
  }
  return std::string(file_name);
}

// Prepares `config` for running queries against the CSV file named by the
// --data_set flag.
//
// In order, this: installs VerifyAnonymizationParametersVisitor as the
// resolved-AST callback, points the catalog at the generated descriptor pool,
// applies the tool-mode flags, loads the CSV file into a table named after
// the file, passes the --userid_col value to the table's anonymization info,
// enables the anonymization rewrite, hands the table to the catalog, turns on
// the maximum development language features and registers the ZetaSQL
// functions for that language configuration.
//
// Returns the first non-OK status produced by any of those steps, or
// OkStatus() when all succeed.
static absl::Status InitializeExecuteQueryConfig(
    zetasql::ExecuteQueryConfig& config) {
  // Every resolved statement is run through the parameter verifier.
  config.set_examine_resolved_ast_callback(
      [](const zetasql::ResolvedNode* resolved_node) -> absl::Status {
        VerifyAnonymizationParametersVisitor parameter_checker;
        return resolved_node->Accept(&parameter_checker);
      });

  auto& catalog = config.mutable_catalog();
  catalog.SetDescriptorPool(
      google::protobuf::DescriptorPool::generated_pool());

  RETURN_IF_ERROR(SetToolModeFromFlags(config));

  // The table is named after the CSV file it is loaded from.
  const std::string csv_path = absl::GetFlag(FLAGS_data_set);
  const std::string csv_table_name = GetCSVFileNameFromPath(csv_path);
  ASSIGN_OR_RETURN(std::unique_ptr<zetasql::SimpleTable> csv_table,
                   zetasql::MakeTableFromCsvFile(csv_table_name, csv_path));

  const std::string user_id_column = absl::GetFlag(FLAGS_userid_col);
  RETURN_IF_ERROR(csv_table->SetAnonymizationInfo({user_id_column}));

  auto& analyzer_options = config.mutable_analyzer_options();
  analyzer_options.set_enabled_rewrites({zetasql::REWRITE_ANONYMIZATION});

  catalog.AddOwnedTable(std::move(csv_table));

  // Functions are registered only after the language features are enabled,
  // because registration takes the language options as its argument.
  analyzer_options.mutable_language()
      ->EnableMaximumLanguageFeaturesForDevelopment();
  catalog.AddZetaSQLFunctions(config.analyzer_options().language());

  return absl::OkStatus();
}

// Entry point. Parses the command-line flags, joins the remaining positional
// arguments with single spaces to form the SQL statement, configures the
// query tool and executes the statement, writing results to stdout.
//
// Returns 0 on success and 1 when configuration, writer creation or query
// execution fails (the error is printed prefixed with "ERROR: "). Terminates
// with the usage message when no SQL statement is given.
int main(int argc, char* argv[]) {
  static constexpr char kUsage[] =
      "Usage: execute_query --data_set=<path_to_csv_file> "
      "--userid_col=<userid_column_name_in_data_set> <sql_statement>\n";
  const std::vector<char*> remaining_args =
      absl::ParseCommandLine(argc, argv);
  // remaining_args[0] is the program name, so the SQL statement is present
  // only if something is left after it. Checking argc instead would accept
  // an invocation that passes flags but no statement.
  if (remaining_args.size() <= 1) {
    LOG(QFATAL) << kUsage;
  }
  const std::string sql =
      absl::StrJoin(remaining_args.begin() + 1, remaining_args.end(), " ");

  zetasql::ExecuteQueryConfig config;
  absl::Status status = InitializeExecuteQueryConfig(config);
  if (!status.ok()) {
    std::cout << "ERROR: " << status << std::endl;
    return 1;
  }

  auto writer = zetasql::MakeWriterFromFlags(config, std::cout);
  if (!writer.status().ok()) {
    std::cout << "ERROR: " << writer.status() << std::endl;
    return 1;
  }

  status = ExecuteQuery(sql, config, *writer.value());
  if (!status.ok()) {
    std::cout << "ERROR: " << status << std::endl;
    return 1;
  }

  return 0;
}
