blob: 37aca4ff3e8b62d7850214306d7e06c9fa658988 [file] [log] [blame]
// Copyright 2019 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COBALT_SRC_ALGORITHMS_PRIVACY_COUNT_MIN_H_
#define COBALT_SRC_ALGORITHMS_PRIVACY_COUNT_MIN_H_
#include <cstddef>
#include <cstdint>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "src/algorithms/privacy/hash.h"
#include "src/public/lib/status.h"
#include "src/public/lib/statusor/statusor.h"
namespace cobalt {
// Implements Count-Min Sketch.
//
// The count-min sketch represents a distribution of hashable observations as |num_hashes| arrays
// of size |num_cells_per_hash|. It relies on a fixed choice of |num_hashes| hash functions mapping
// the observation space into the range [0, ..., |num_cells_per_hash| - 1].
//
// This implementation flattens the representation into a single vector of integer-valued cells.
// The cell index range corresponding to the kth hash function begins at (k * |num_cells_per_hash|)
// and ends (inclusive) at ((k + 1) * |num_cells_per_hash| - 1).
//
// Incrementing the count for an observation |data| has the effect of incrementing the |num_hashes|
// cells in the set
// cells(|data|) = { (k * num_cells_per_hash + h_k(|data|)) for k in {0, ..., |num_hashes| - 1} }.
//
// To estimate the recorded count for |data|, CountMin computes the minimum value of the cells in
// cells(|data|).
template <typename CountType>
class CountMin {
 public:
  static_assert(std::is_arithmetic<CountType>::value,
                "CountMin is only valid for arithmetic types.");

  // Returns a CountMin sketch with dimensions |num_cells_per_hash| and |num_hashes| and with
  // |num_cells_per_hash| * |num_hashes| zero-valued cells.
  //
  // NOTE(review): the product of the two dimensions is not checked for overflow here because this
  // factory has no way to report an error; callers are presumably expected to pass sane sizes.
  static CountMin<CountType> MakeSketch(size_t num_cells_per_hash, size_t num_hashes) {
    CountMin count_min;
    count_min.initialize_with_zeros(num_cells_per_hash, num_hashes);
    return count_min;
  }

  // Returns a CountMin sketch with dimensions |num_cells_per_hash| and |num_hashes| containing
  // |cells| if the size of |cells| is equal to |num_cells_per_hash| * |num_hashes|, or an error
  // status if not.
  static lib::statusor::StatusOr<CountMin<CountType>> MakeSketchFromCells(
      size_t num_cells_per_hash, size_t num_hashes, std::vector<CountType> cells) {
    CountMin count_min;
    // |cells| is a by-value sink parameter: move it along so the vector is not copied again.
    if (Status init_status =
            count_min.initialize_with_cells(num_cells_per_hash, num_hashes, std::move(cells));
        !init_status.ok()) {
      return init_status;
    }
    return count_min;
  }

  // Returns the number of cells in the sketch.
  [[nodiscard]] size_t size() const { return cells_.size(); }

  // Increments the number of observations of |data| by |count|.
  void Increment(const std::string& data, CountType count) {
    Increment(data.data(), data.size(), count);
  }

  // Returns a vector of the |num_hashes| indices corresponding to |data|, without updating the
  // sketch. Returns an empty vector if the sketch has no cells, since no valid index exists.
  [[nodiscard]] std::vector<size_t> GetCellIndices(const std::string& data) const {
    std::vector<size_t> indices;
    if (cells_.empty()) {
      return indices;
    }
    indices.reserve(num_hashes_);
    for (size_t i = 0; i < num_hashes_; ++i) {
      indices.push_back(GetSketchCell(data.data(), data.size(), i));
    }
    return indices;
  }

  // Gets the estimated count for the specified |data|.
  [[nodiscard]] CountType GetCount(const std::string& data) const {
    return GetCount(data.data(), data.size());
  }

  // Increments the value at |cell_index| by |count|. Returns an error status if |cell_index| is not
  // a valid cell index.
  Status IncrementCell(size_t cell_index, CountType count) {
    if (cell_index >= cells_.size()) {
      return Status(StatusCode::OUT_OF_RANGE, "cell index is out of range.");
    }
    cells_[cell_index] += count;
    return Status::OkStatus();
  }

  // Gets the value of the sketch cell with index |cell_index|. Returns an error status if
  // |cell_index| is not a valid cell index.
  [[nodiscard]] lib::statusor::StatusOr<CountType> GetCellValue(size_t cell_index) const {
    if (cell_index >= cells_.size()) {
      return Status(StatusCode::OUT_OF_RANGE, "cell index is out of range.");
    }
    return cells_[cell_index];
  }

 private:
  // Construction goes through the static factories so that the dimensions and |cells_| are always
  // set together.
  CountMin() = default;

  // Sets the dimensions of the sketch and sets |cells_| to a zero vector of length
  // |num_cells_per_hash| * |num_hashes|.
  void initialize_with_zeros(size_t num_cells_per_hash, size_t num_hashes) {
    num_cells_per_hash_ = num_cells_per_hash;
    num_hashes_ = num_hashes;
    // The sized vector constructor value-initializes every cell, i.e. to zero for arithmetic types.
    cells_ = std::vector<CountType>(num_cells_per_hash * num_hashes);
  }

  // Sets the dimensions of the sketch and sets |cells_| to the provided vector |cells|. Returns an
  // error status if the size of |cells| is not equal to |num_cells_per_hash| * |num_hashes|.
  Status initialize_with_cells(size_t num_cells_per_hash, size_t num_hashes,
                               std::vector<CountType> cells) {
    // Compare by division rather than by multiplying the dimensions: the product can wrap around
    // for very large dimensions and then spuriously match the size of a small |cells| vector.
    const bool dimensions_match =
        (num_hashes == 0) ? cells.empty()
                          : (cells.size() % num_hashes == 0 &&
                             cells.size() / num_hashes == num_cells_per_hash);
    if (!dimensions_match) {
      return Status(StatusCode::INVALID_ARGUMENT,
                    "number of cells is not compatible with sketch dimensions.");
    }
    num_cells_per_hash_ = num_cells_per_hash;
    num_hashes_ = num_hashes;
    cells_ = std::move(cells);
    return Status::OkStatus();
  }

  // Increments the number of observations of |data| which has length |len| by |count|. Does
  // nothing if the sketch has no cells.
  void Increment(const char* data, size_t len, CountType count) {
    if (cells_.empty()) {
      // Nothing to update; also avoids indexing an empty vector when the rows have zero width.
      return;
    }
    for (size_t i = 0; i < num_hashes_; ++i) {
      cells_[GetSketchCell(data, len, i)] += count;
    }
  }

  // Gets the estimated count for the specified |data| which has length |len|: the minimum over
  // the |num_hashes_| cells that |data| maps to. Returns 0 if the sketch has no cells.
  [[nodiscard]] CountType GetCount(const char* data, size_t len) const {
    CountType min_count = 0;
    if (cells_.empty()) {
      return min_count;
    }
    for (size_t i = 0; i < num_hashes_; ++i) {
      const CountType count = cells_[GetSketchCell(data, len, i)];
      // The first cell seeds the minimum; later cells only ever lower it.
      if (i == 0 || count < min_count) {
        min_count = count;
      }
    }
    return min_count;
  }

  // Returns the index in the sketch cells of the |data| of length |len| for hash function
  // |hash_index|. Must only be called when |num_cells_per_hash_| is non-zero.
  size_t GetSketchCell(const char* data, size_t len, size_t hash_index) const {
    // The per-hash offset must be bounded by the row width |num_cells_per_hash_| (not by the number
    // of hash functions) so that the result stays inside row |hash_index| of the flattened sketch.
    // NOTE(review): assumes the last argument of TruncatedDigest is the exclusive upper bound of
    // the returned digest -- confirm against src/algorithms/privacy/hash.h.
    const size_t offset_in_row = TruncatedDigest(reinterpret_cast<const uint8_t*>(data), len,
                                                 hash_index, num_cells_per_hash_);
    return hash_index * num_cells_per_hash_ + offset_in_row;
  }

  // Number of cells in the row owned by each hash function.
  size_t num_cells_per_hash_ = 0;
  // Number of hash functions, i.e. number of rows in the flattened sketch.
  size_t num_hashes_ = 0;
  // Flattened sketch: row k occupies indices [k * num_cells_per_hash_, (k + 1) *
  // num_cells_per_hash_).
  std::vector<CountType> cells_;
};
} // namespace cobalt
#endif // COBALT_SRC_ALGORITHMS_PRIVACY_COUNT_MIN_H_