// blob: 190ad00ae44f5d2f425f684f05b1deced0e40b2a [file] [log] [blame]
//
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "algorithms/internal/count-tree.h"
#include <unordered_map>
#include "absl/status/status.h"
#include "base/statusor.h"
#include "algorithms/algorithm.h"
#include "algorithms/bounded-algorithm.h"
#include "algorithms/numerical-mechanisms.h"
#include "proto/util.h"
#include "proto/summary.pb.h"
#include "base/canonical_errors.h"
#include "base/status_macros.h"
namespace differential_privacy {
namespace internal {
namespace {

// Returns base^exponent using exact integer arithmetic. A negative exponent
// yields 0, which is what truncating the fractional floating-point power to
// an int gives for any base >= 2. No overflow checking is performed.
int CountTreeIntegerPower(int base, int exponent) {
  if (exponent < 0) {
    return 0;
  }
  int result = 1;
  for (int i = 0; i < exponent; ++i) {
    result *= base;
  }
  return result;
}

// Returns 1 + b + b^2 + ... + b^height, i.e. the number of nodes in a
// complete tree with the given branching factor b and height. Evaluated with
// Horner's scheme so that no intermediate value exceeds the final result and
// no division is needed (a branching factor of 1 is therefore handled too).
// Returns 0 for a negative height.
int CountTreeNodeCount(int branching_factor, int height) {
  if (height < 0) {
    return 0;
  }
  int total = 1;
  for (int level = 0; level < height; ++level) {
    total = total * branching_factor + 1;
  }
  return total;
}

}  // namespace

// Builds an empty count tree with the given height and branching factor.
//
// The derived sizes are computed with integer arithmetic rather than
// std::pow: a floating-point power that comes out marginally below the exact
// integer would be truncated to the wrong node or leaf count. Every
// initializer reads the constructor parameters instead of other members, so
// the result does not depend on the members' declaration order.
CountTree::CountTree(int height, int branching_factor)
    : height_(height),
      branching_factor_(branching_factor),
      number_of_nodes_(CountTreeNodeCount(branching_factor, height)),
      number_of_leaves_(CountTreeIntegerPower(branching_factor, height)),
      // The leaves occupy the last number_of_leaves_ indices.
      left_most_leaf_(CountTreeNodeCount(branching_factor, height) -
                      CountTreeIntegerPower(branching_factor, height)) {}
// Returns the index of the left-most leaf node.
int CountTree::GetLeftMostLeaf() const {
  return left_most_leaf_;
}
// Returns the index of the leaf that lies n positions to the right of the
// left-most leaf (n == 0 is the left-most leaf itself). No range check is
// performed on n.
int CountTree::GetNthLeaf(int n) const {
  const int first_leaf = GetLeftMostLeaf();
  return first_leaf + n;
}
// Returns the total number of nodes in the tree, as computed at construction.
int CountTree::GetNumberOfNodes() const {
  return number_of_nodes_;
}
// Returns the number of leaf nodes in the tree, as computed at construction.
int CountTree::GetNumberOfLeaves() const {
  return number_of_leaves_;
}
// Returns the branching factor the tree was constructed with.
int CountTree::GetBranchingFactor() const {
  return branching_factor_;
}
// Returns the height the tree was constructed with.
int CountTree::GetHeight() const {
  return height_;
}
// Returns the index of the root node.
int CountTree::GetRoot() const {
  return root_node_;
}
// Returns the index of the parent of `nodeIndex`. This is the inverse of the
// child mapping, in which node i has children i * b + 1 through (i + 1) * b
// for branching factor b.
int CountTree::Parent(int nodeIndex) const {
  const int shifted = nodeIndex - 1;
  return shifted / branching_factor_;
}
// Returns the index of the left-most child of `nodeIndex`, which is
// nodeIndex * b + 1 for branching factor b.
int CountTree::LeftMostChild(int nodeIndex) const {
  const int scaled = nodeIndex * branching_factor_;
  return scaled + 1;
}
// Returns the index of the right-most child of `nodeIndex`, which is
// (nodeIndex + 1) * b for branching factor b.
int CountTree::RightMostChild(int nodeIndex) const {
  const int successor = nodeIndex + 1;
  return successor * branching_factor_;
}
// Returns true iff `nodeIndex` lies in the leaf range, i.e. from the
// left-most leaf (inclusive) up to the total number of nodes (exclusive).
bool CountTree::IsLeaf(int nodeIndex) const {
  if (nodeIndex < GetLeftMostLeaf()) {
    return false;
  }
  return nodeIndex < GetNumberOfNodes();
}
// Returns the index of the left-most leaf in the subtree rooted at
// `nodeIndex`, found by repeatedly stepping to the left-most child until a
// leaf is reached. A leaf is returned unchanged.
int CountTree::LeftMostInSubtree(int nodeIndex) const {
  int current = nodeIndex;
  for (;;) {
    if (IsLeaf(current)) {
      return current;
    }
    current = LeftMostChild(current);
  }
}
// Returns the index of the right-most leaf in the subtree rooted at
// `nodeIndex`, found by repeatedly stepping to the right-most child until a
// leaf is reached. A leaf is returned unchanged.
int CountTree::RightMostInSubtree(int nodeIndex) const {
  int current = nodeIndex;
  for (;;) {
    if (IsLeaf(current)) {
      return current;
    }
    current = RightMostChild(current);
  }
}
// Adds one to the count stored for `nodeIndex`. The map's operator[] creates
// the entry first if the node has no count yet.
void CountTree::IncrementNode(int nodeIndex) {
  tree_[nodeIndex] += 1;
}
// Adds `increment` to the count stored for `nodeIndex`. The map's operator[]
// creates the entry first if the node has no count yet.
void CountTree::IncrementNodeBy(int nodeIndex, int64_t increment) {
  auto& count = tree_[nodeIndex];
  count += increment;
}
// Removes every stored node count; afterwards GetNodeCount() reports 0 for
// all nodes.
void CountTree::ClearNodes() {
  tree_.clear();
}
// Returns the count stored for `nodeIndex`, or 0 when the node has no entry.
// Uses find() rather than operator[] so that a lookup never inserts.
int64_t CountTree::GetNodeCount(int nodeIndex) const {
  const auto entry = tree_.find(nodeIndex);
  return entry == tree_.end() ? 0 : entry->second;
}
// Produces a BoundedQuantilesSummary holding this tree's height, branching
// factor, and every stored (node index, count) entry.
BoundedQuantilesSummary CountTree::Serialize() {
  BoundedQuantilesSummary summary;
  // Shape parameters; Merge() requires them to match the receiving tree.
  summary.set_tree_height(height_);
  summary.set_branching_factor(branching_factor_);
  // Copy all stored node counts into the summary's map field.
  auto* serialized_nodes = summary.mutable_quantile_tree();
  serialized_nodes->insert(tree_.begin(), tree_.end());
  return summary;
}
// Adds the per-node counts stored in `summary` to this tree's counts.
//
// Returns an InternalError, leaving the tree unmodified, when the summary
// lacks a tree height or branching factor, or when either value differs from
// this tree's. Returns OkStatus once all counts have been added.
absl::Status CountTree::Merge(const BoundedQuantilesSummary& summary) {
  if (!summary.has_tree_height() || !summary.has_branching_factor()) {
    return absl::InternalError(
        "Summary missing height and/or branching factor.");
  }
  if (summary.tree_height() != height_) {
    return absl::InternalError(
        absl::StrCat("Height mismatch. Tree had: ", height_,
                     " but summary had: ", summary.tree_height()));
  }
  if (summary.branching_factor() != branching_factor_) {
    return absl::InternalError(
        absl::StrCat("Branching factor mismatch. Tree had: ", branching_factor_,
                     " but summary had: ", summary.branching_factor()));
  }
  // Bind each map entry by const reference: naming a std::pair value type
  // here would force a conversion and copy of every entry.
  for (const auto& node : summary.quantile_tree()) {
    tree_[node.first] += node.second;
  }
  return absl::OkStatus();
}
// Returns an estimate, in bytes, of the memory held by this tree.
//
// Makes some guesses about how unordered_map is likely implemented: one
// pointer per bucket plus one key/value pair per stored entry. Every entry
// lives in exactly one bucket, so the per-bucket sizes sum to tree_.size()
// and no loop over the buckets (or signed/unsigned index comparison) is
// needed.
int64_t CountTree::MemoryUsed() {
  // Pointer to array, one per bucket.
  const int64_t bucket_pointer_bytes =
      static_cast<int64_t>(tree_.bucket_count() * sizeof(int64_t*));
  // Keys and values in an array.
  const int64_t entry_bytes =
      static_cast<int64_t>(tree_.size() * (sizeof(int64_t) + sizeof(int)));
  return static_cast<int64_t>(sizeof(CountTree)) + bucket_pointer_bytes +
         entry_bytes;
}
} // namespace internal
} // namespace differential_privacy