blob: 2bff3b7b92afcdd8a9776204c65caa132c4d9125 [file] [log] [blame]
//
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Tool for running a query with ZetaSQL. Supports reading from a csv file.
#include <math.h>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "base/logging.h"
#include "google/protobuf/descriptor.h"
#include "zetasql/public/analyzer_options.h"
#include "zetasql/public/catalog.h"
#include "zetasql/public/language_options.h"
#include "zetasql/public/options.pb.h"
#include "zetasql/public/simple_catalog.h"
#include "zetasql/public/type.pb.h"
#include "zetasql/public/value.h"
#include "zetasql/resolved_ast/resolved_ast.h"
#include "zetasql/resolved_ast/resolved_ast_visitor.h"
#include "zetasql/resolved_ast/resolved_node.h"
#include "zetasql/resolved_ast/resolved_node_kind.pb.h"
#include "zetasql/tools/execute_query/execute_query_tool.h"
#include "absl/flags/flag.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "base/status_macros.h"
#include "absl/flags/parse.h"
// --data_set: path of the CSV file whose contents are loaded as the table
// being queried.
ABSL_FLAG(std::string, data_set, "",
          "A CSV file containing the data to be queried, whose string-typed "
          "column names are determined from the first (header) row. The data "
          "is loaded into a table with the same name as the file (without the "
          ".csv extension, if it exists).");
// --userid_col: name of the column that identifies the user owning each row;
// it is registered as the table's anonymization user-id column.
ABSL_FLAG(std::string, userid_col, "",
          "A string matching the name of the column in the data set "
          "containing the user IDs, to be used in anonymization queries.");
// Resolved-AST visitor that verifies the anonymization options attached to an
// anonymized aggregate scan are within valid bounds:
//   * every option value must be an INT64 or DOUBLE literal and positive;
//   * DELTA must additionally be at most 1;
//   * KAPPA must additionally be a whole number;
//   * K_THRESHOLD is rejected (DELTA must be used instead);
//   * EPSILON, DELTA and KAPPA must all be present.
class VerifyAnonymizationParametersVisitor
    : public zetasql::ResolvedASTVisitor {
 public:
  // We only need a special visitor function for AnonymizationAggregateScans.
  //
  // Returns OkStatus when all options of `node` are valid, InvalidArgument
  // when an option is missing, non-literal or out of range, and Internal when
  // a literal option has an unexpected type.
  //
  // NOTE(review): this override does not descend into the children of `node`;
  // confirm that anonymized aggregate scans cannot be nested below one
  // another.
  absl::Status VisitResolvedAnonymizedAggregateScan(
      const zetasql::ResolvedAnonymizedAggregateScan* node) override {
    bool epsilon_provided = false;
    bool delta_provided = false;
    bool kappa_provided = false;
    for (auto const& anon_option : node->anonymization_option_list()) {
      // Option names are normalised to upper case so that both the
      // comparisons below and the error messages are case-independent.
      const std::string name = absl::AsciiStrToUpper(anon_option->name());
      const zetasql::ResolvedExpr* anon_option_expr = anon_option->value();

      // Only literal values can be inspected here.
      if (anon_option_expr->node_kind() !=
          zetasql::ResolvedNodeKind::RESOLVED_LITERAL) {
        return absl::InvalidArgumentError(absl::StrCat(
            "The value of anonymization option ", name, " cannot be ",
            "interpreted, since it is not a literal, but is a ",
            anon_option_expr->node_kind_string(), "."));
      }

      // Extract the anonymization option value as a double.
      const zetasql::Value anon_option_value =
          anon_option_expr->GetAs<zetasql::ResolvedLiteral>()->value();
      double anon_option_double = 0.0;
      switch (anon_option_value.type_kind()) {
        case zetasql::TypeKind::TYPE_INT64:
          anon_option_double =
              static_cast<double>(anon_option_value.int64_value());
          break;
        case zetasql::TypeKind::TYPE_DOUBLE:
          anon_option_double = anon_option_value.double_value();
          break;
        default:  // Unexpected anon_option_value.type_kind()
          return absl::InternalError(absl::StrCat(
              "Anonymization option ", name,
              " is expected to be parsed as either an INT64 or",
              " DOUBLE, but is a ",
              anon_option_value.type()->ShortTypeName(
                  zetasql::PRODUCT_EXTERNAL),
              "."));
      }

      // Return an error if any anonymization parameter is not positive. The
      // negated comparison also rejects NaN, which `<= 0` would let through.
      if (!(anon_option_double > 0)) {
        return absl::InvalidArgumentError(absl::StrCat(
            "Anonymization option ", name, " must be positive, but is ",
            anon_option_double, "."));
      }

      if (name == "EPSILON") {
        epsilon_provided = true;
      } else if (name == "DELTA") {
        delta_provided = true;
        // Return an error if delta is provided and is larger than 1.
        if (anon_option_double > 1) {
          return absl::InvalidArgumentError(absl::StrCat(
              "Anonymization option ", name,
              " must be greater than 0 and less than or equal to 1, but",
              " is ", anon_option_double, "."));
        }
      } else if (name == "KAPPA") {
        kappa_provided = true;
        // Return an error if kappa is specified but is not integer,
        // in case the SQL interpreter did not catch it first.
        double intpart;
        const double fraction = modf(anon_option_double, &intpart);
        if (fraction != 0.0) {
          return absl::InvalidArgumentError(absl::StrCat(
              "Anonymization option ", name, " must be an integer, but is ",
              anon_option_double, "."));
        }
      } else if (name == "K_THRESHOLD") {
        // Return an error if k_threshold is specified. Delta should be used
        // instead.
        return absl::InvalidArgumentError(
            "Please use DELTA instead of K_THRESHOLD. DELTA can be"
            " calculated using Theorem 2 of Wilson et al.'s paper on"
            " Differentially Private SQL with Bounded User Contribution"
            " (available at https://arxiv.org/pdf/1909.01917.pdf).");
      }
    }

    // Return an error if epsilon, delta, or kappa are not provided.
    if (!(epsilon_provided && delta_provided && kappa_provided)) {
      return absl::InvalidArgumentError(
          "ZetaSQL differential privacy queries must specify EPSILON, "
          "DELTA, and KAPPA in the WITH ANONYMIZATION OPTIONS() clause.");
    }
    return absl::OkStatus();
  }
};
// Returns the last '/'-separated component of `file_path` with a single
// trailing ".csv" or ".CSV" extension removed, if present.
//
// Only one extension is stripped, so "a.CSV.csv" yields "a.CSV". A path that
// is empty or ends in '/' yields an empty name.
static std::string GetCSVFileNameFromPath(const std::string_view file_path) {
  // Everything after the last '/' is the file name; without any '/', the whole
  // path is the file name.
  const auto last_slash = file_path.rfind('/');
  std::string_view file_name = (last_slash == std::string_view::npos)
                                   ? file_path
                                   : file_path.substr(last_slash + 1);

  constexpr std::string_view kLowerExtension = ".csv";
  constexpr std::string_view kUpperExtension = ".CSV";
  if (file_name.size() >= kLowerExtension.size()) {
    const std::string_view tail =
        file_name.substr(file_name.size() - kLowerExtension.size());
    if (tail == kLowerExtension || tail == kUpperExtension) {
      file_name.remove_suffix(kLowerExtension.size());
    }
  }
  return std::string(file_name);
}
// Prepares `config` for running an anonymization query against the CSV file
// named by --data_set:
//   1. installs a resolved-AST callback that runs
//      VerifyAnonymizationParametersVisitor over the resolved node;
//   2. points the catalog at the generated protobuf descriptor pool and
//      applies SetToolModeFromFlags;
//   3. loads the CSV file into a table named after the file, marks
//      --userid_col as its anonymization user-id column and hands the table
//      to the catalog;
//   4. enables the anonymization rewrite and the maximum set of development
//      language features, then registers the ZetaSQL functions for those
//      language options.
//
// Returns InvalidArgument if --data_set or --userid_col is empty; otherwise
// the first error reported by any of the steps above, or OkStatus.
static absl::Status InitializeExecuteQueryConfig(
    zetasql::ExecuteQueryConfig& config) {
  config.set_examine_resolved_ast_callback(
      [](const zetasql::ResolvedNode* node) -> absl::Status {
        VerifyAnonymizationParametersVisitor visitor;
        return node->Accept(&visitor);
      });
  config.mutable_catalog().SetDescriptorPool(
      google::protobuf::DescriptorPool::generated_pool());
  RETURN_IF_ERROR(SetToolModeFromFlags(config));

  // Both flags are required; fail early with an actionable message rather
  // than relying on whatever error an empty value triggers further down.
  const std::string file_path = absl::GetFlag(FLAGS_data_set);
  if (file_path.empty()) {
    return absl::InvalidArgumentError(
        "--data_set must be set to the path of a CSV file.");
  }
  const std::string userid_col = absl::GetFlag(FLAGS_userid_col);
  if (userid_col.empty()) {
    return absl::InvalidArgumentError(
        "--userid_col must be set to the name of the user ID column in the "
        "data set.");
  }

  const std::string table_name = GetCSVFileNameFromPath(file_path);
  ASSIGN_OR_RETURN(std::unique_ptr<zetasql::SimpleTable> table,
                   zetasql::MakeTableFromCsvFile(table_name, file_path));
  RETURN_IF_ERROR(table->SetAnonymizationInfo({userid_col}));
  config.mutable_analyzer_options().set_enabled_rewrites(
      {zetasql::REWRITE_ANONYMIZATION});
  config.mutable_catalog().AddOwnedTable(std::move(table));

  // The language options must be finalised before the functions are added,
  // because AddZetaSQLFunctions is passed those options.
  config.mutable_analyzer_options()
      .mutable_language()
      ->EnableMaximumLanguageFeaturesForDevelopment();
  config.mutable_catalog().AddZetaSQLFunctions(
      config.analyzer_options().language());
  return absl::OkStatus();
}
// Entry point: parses the command-line flags, joins the remaining positional
// arguments into one SQL statement, and executes it against the configured
// catalog, writing results to stdout.
//
// Returns 0 on success and 1 if configuration, writer creation or query
// execution fails (the error is printed to stdout prefixed with "ERROR: ").
// Terminates with the usage message if no SQL statement was supplied.
int main(int argc, char* argv[]) {
  const char kUsage[] =
      "Usage: execute_query --data_set=<path_to_csv_file> "
      "--userid_col=<userid_column_name_in_data_set> <sql_statement>\n";
  std::vector<char*> remaining_args = absl::ParseCommandLine(argc, argv);
  // remaining_args holds the program name followed by the positional
  // arguments left after flag parsing. Checking it (rather than argc) also
  // catches invocations that pass flags but no SQL statement.
  if (remaining_args.size() <= 1) {
    LOG(QFATAL) << kUsage;
  }
  // Skip the program name; everything else forms the SQL statement.
  const std::string sql = absl::StrJoin(remaining_args.begin() + 1,
                                        remaining_args.end(), " ");

  zetasql::ExecuteQueryConfig config;
  absl::Status status = InitializeExecuteQueryConfig(config);
  if (!status.ok()) {
    std::cout << "ERROR: " << status << std::endl;
    return 1;
  }

  auto writer = zetasql::MakeWriterFromFlags(config, std::cout);
  if (!writer.status().ok()) {
    std::cout << "ERROR: " << writer.status() << std::endl;
    return 1;
  }

  status = ExecuteQuery(sql, config, *writer.value());
  if (!status.ok()) {
    std::cout << "ERROR: " << status << std::endl;
    return 1;
  }
  return 0;
}