| // Copyright 2019 The Fuchsia Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #ifndef COBALT_SRC_ALGORITHMS_PRIVACY_COUNT_MIN_H_ |
| #define COBALT_SRC_ALGORITHMS_PRIVACY_COUNT_MIN_H_ |
| |
| #include <cstddef> |
| #include <cstdint> |
| #include <string> |
| #include <type_traits> |
| #include <utility> |
| #include <vector> |
| |
| #include "src/algorithms/privacy/hash.h" |
| #include "src/lib/statusor/statusor.h" |
| #include "src/lib/util/status.h" |
| |
| namespace cobalt { |
| |
| // Implements Count-Min Sketch. |
| // |
| // The count-min sketch represents a distribution of hashable observations as |num_hashes| arrays |
| // of size |num_cells_per_hash|. It relies on a fixed choice of |num_hashes| hash functions mapping |
| // the observation space into the range [0,..., |num_cells_per_hash|]. |
| // |
| // This implementation flattens the representation into a single vector of integer-valued cells. |
| // The cell index range corresponding the kth hash function begins at (k * |num_cells_per_hash|) and |
| // ends (inclusive) at ((k + 1) * |num_cells_per_hash| - 1). |
| // |
| // Incrementing the count for an observation |data| has the effect of incrementing the |num_hashes| |
| // cells in the set |
| // cells(|data|) = { (k * num_cells_per_hash + h_k(|data|)) for k = {1, ..., |num_hashes|} }. |
| // |
| // To estimate the recorded count for |data|, CountMin computes the minimum value of the cells in |
| // cells(|data|). |
| template <typename CountType> |
| class CountMin { |
| public: |
| static_assert(std::is_arithmetic<CountType>::value, |
| "CountMin is only valid for arithmetic types."); |
| |
| // Returns a CountMin sketch with dimensions |num_cells_per_hash| and |num_hashes| and with |
| // |num_cells_per_hash| * |num_hashes| zero-valued cells. |
| static CountMin<CountType> MakeSketch(size_t num_cells_per_hash, size_t num_hashes) { |
| CountMin count_min; |
| count_min.initialize_with_zeros(num_cells_per_hash, num_hashes); |
| return count_min; |
| } |
| |
| // Returns a CountMin sketch with dimensions |num_cells_per_hash| and |num_hashes| containing |
| // |cells| if the size of |cells| is equal to |num_cells_per_hash| * |num_hashes|, or an error |
| // status if not. |
| static lib::statusor::StatusOr<CountMin<CountType>> MakeSketchFromCells( |
| size_t num_cells_per_hash, size_t num_hashes, std::vector<CountType> cells) { |
| CountMin count_min; |
| if (util::Status init_status = |
| count_min.initialize_with_cells(num_cells_per_hash, num_hashes, cells); |
| !init_status.ok()) { |
| return init_status; |
| } |
| return count_min; |
| } |
| |
| // Returns the number of cells in the sketch. |
| [[nodiscard]] size_t size() const { return cells_.size(); } |
| |
| // Increments the number of observations of |data| by |count|. |
| void Increment(const std::string& data, CountType count) { |
| Increment(data.data(), data.size(), count); |
| } |
| |
| // Returns a vector of the |num_hashes| indices corresponding to |data|, without updating the |
| // sketch. |
| [[nodiscard]] std::vector<size_t> GetCellIndices(const std::string& data) const { |
| std::vector<size_t> indices; |
| for (size_t i = 0; i < num_hashes_; ++i) { |
| indices.push_back(GetSketchCell(data.data(), data.size(), i)); |
| } |
| return indices; |
| } |
| |
| // Gets the estimated count for the specified |data|. |
| [[nodiscard]] CountType GetCount(const std::string& data) const { |
| return GetCount(data.data(), data.size()); |
| } |
| |
| // Increments the value at |cell_index| by |count|. Returns an error status if |cell_index| is not |
| // a valid cell index. |
| util::Status IncrementCell(size_t cell_index, CountType count) { |
| if (cell_index >= cells_.size()) { |
| return util::Status(util::OUT_OF_RANGE, "cell index is out of range."); |
| } |
| cells_[cell_index] += count; |
| return util::Status::OK; |
| } |
| |
| // Gets the value of the sketch cell with index |cell_index|. Returns an error status if |
| // |cell_index| is not a valid cell index. |
| [[nodiscard]] lib::statusor::StatusOr<CountType> GetCellValue(size_t cell_index) const { |
| if (cell_index >= cells_.size()) { |
| return util::Status(util::OUT_OF_RANGE, "cell index is out of range."); |
| } |
| return cells_[cell_index]; |
| } |
| |
| private: |
| CountMin() = default; |
| |
| // Sets the dimensions of the sketch and sets |cells_| to a zero vector of length |
| // |num_cells_per_hash| * |num_hashes|. |
| void initialize_with_zeros(size_t num_cells_per_hash, size_t num_hashes) { |
| num_cells_per_hash_ = num_cells_per_hash; |
| num_hashes_ = num_hashes; |
| cells_ = std::vector<CountType>(num_cells_per_hash * num_hashes); |
| } |
| |
| // Sets the dimensions of the sketch and sets |cells_| to the provided vector |cells|. Returns an |
| // error status if the size of |cells| is not equal to |num_cells_per_hash| * |num_hashes|. |
| util::Status initialize_with_cells(size_t num_cells_per_hash, size_t num_hashes, |
| std::vector<CountType> cells) { |
| if (cells.size() != num_cells_per_hash * num_hashes) { |
| return util::Status(util::INVALID_ARGUMENT, |
| "number of cells is not compatible with sketch dimensions."); |
| } |
| num_cells_per_hash_ = num_cells_per_hash; |
| num_hashes_ = num_hashes; |
| cells_ = std::move(cells); |
| return util::Status::OK; |
| } |
| |
| // Increments the number of observations of |data| which has length |len| by |count|. |
| void Increment(const char* data, size_t len, CountType count) { |
| for (size_t i = 0; i < num_hashes_; ++i) { |
| cells_[GetSketchCell(data, len, i)] += count; |
| } |
| } |
| |
| // Gets the estimated count for the specified |data| which has length |len|. |
| [[nodiscard]] CountType GetCount(const char* data, size_t len) const { |
| CountType min_count = 0; |
| for (size_t i = 0; i < num_hashes_; ++i) { |
| CountType count = cells_[GetSketchCell(data, len, i)]; |
| if (count < min_count || i == 0) { |
| min_count = count; |
| } |
| } |
| return min_count; |
| } |
| |
| // Returns the index in the sketch cells of the |data| of length |len| for hash function |
| // |hash_index|. |
| size_t GetSketchCell(const char* data, size_t len, size_t hash_index) const { |
| return hash_index * num_cells_per_hash_ + |
| TruncatedDigest(reinterpret_cast<const uint8_t*>(data), len, hash_index, num_hashes_); |
| } |
| |
| size_t num_cells_per_hash_; |
| size_t num_hashes_; |
| std::vector<CountType> cells_; |
| }; |
| |
| } // namespace cobalt |
| #endif // COBALT_SRC_ALGORITHMS_PRIVACY_COUNT_MIN_H_ |