# blob: 8fae4313bf166d2e4e0c95681f03171338dc9bb7 [file] [log] [blame]
# Copyright 2021 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LSH tree for hierarchically grouping nearby points."""
import dataclasses
import typing
from absl import logging
import numpy as np
from clustering import central_privacy_utils
from clustering import clustering_params
from clustering import coreset_params
from clustering import lsh
# Prefix of an LSH hash, stored as a string. Each level of the tree appends one
# hash character to the prefix of its parent.
HashPrefix = str
@dataclasses.dataclass
class LshTreeNode:
  """One hash prefix of the LSH tree together with the points under it.

  Attributes:
    hash_prefix: Hash prefix that this node stands for.
    nonprivate_points: Points that hash to hash_prefix.
    coreset_param: Clustering params this node was constructed with.
    sim_hash: LSH that generates the hashes.
    private_count: Private count of the points in nonprivate_points. It is
      computed on construction when no value is supplied.
    private_average: Private average of the points in nonprivate_points, or
      None as long as get_private_average has not been called.
  """
  hash_prefix: HashPrefix
  nonprivate_points: np.ndarray
  coreset_param: coreset_params.CoresetParam
  sim_hash: lsh.SimHash
  private_count: typing.Optional[int] = None
  private_average: typing.Optional[np.ndarray] = dataclasses.field(
      default=None, init=False)

  def __post_init__(self):
    # A caller-supplied private count is kept untouched; otherwise it is
    # computed right away so that private_count is populated after construction.
    if self.private_count is None:
      self.get_private_count()

  def get_private_average(self) -> np.ndarray:
    """Returns the private average of the node's points, computing it once.

    The result is stored in self.private_average and reused by later calls.
    Requires that self.private_count >= 1.
    """
    if self.private_average is None:
      privacy_param = self.coreset_param.pcalc.average_privacy_param
      self.private_average = central_privacy_utils.get_private_average(
          self.nonprivate_points, self.private_count, privacy_param,
          self.sim_hash.dim)
    return self.private_average

  def get_private_count(self) -> int:
    """Returns the private count of the node's points, computing it once.

    The result is stored in self.private_count and reused by later calls.
    """
    if self.private_count is None:
      privacy_param = self.coreset_param.pcalc.count_privacy_param
      self.private_count = central_privacy_utils.get_private_count(
          len(self.nonprivate_points), privacy_param)
    return self.private_count

  def children(self) -> typing.List["LshTreeNode"]:
    """Returns the child nodes of this node.

    The points of this node are grouped by their next hash character, and one
    child is created per group; its hash prefix is self.hash_prefix extended by
    that character. Children are returned regardless of
    self.coreset_param.tree_param, so no size filtering happens here.
    """
    points_by_next_char = self.sim_hash.group_by_next_hash(
        self.nonprivate_points, hash_prefix=self.hash_prefix)
    return [
        LshTreeNode(
            hash_prefix=self.hash_prefix + next_char,
            nonprivate_points=child_points,
            coreset_param=self.coreset_param,
            sim_hash=self.sim_hash)
        for next_char, child_points in points_by_next_char.items()
    ]

  def __repr__(self) -> str:
    """Represents nodes in the form of private_count(hash_prefix)."""
    return f"{self.private_count!s}({self.hash_prefix})"
# List of leaves in the LSH tree.
LshTreeLeaves = typing.List[LshTreeNode]
# Index of a particular level of the tree; the root is on level 0.
LevelIndex = int
# Nodes on one level of the tree.
LshTreeLevel = typing.List[LshTreeNode]
# Subset of a level including just the nodes that should be branched.
NodesToBranch = LshTreeLevel
def root_node(data: clustering_params.Data,
              coreset_param: coreset_params.CoresetParam,
              private_count: typing.Optional[int] = None):
  """Builds the root node of an LSH prefix tree.

  The root carries the empty hash prefix and all of data.datapoints. Its LSH is
  created from data.dim and coreset_param.tree_param.max_depth.

  Args:
    data: Data to use for generating the tree.
    coreset_param: Clustering parameters to use for generating the tree.
    private_count: Private count for the number of datapoints. If None, the
      private count will be computed.

  Returns:
    The LshTreeNode to use as the root of the tree.
  """
  max_depth = coreset_param.tree_param.max_depth
  sim_hash = lsh.SimHash(data.dim, max_depth)
  return LshTreeNode(
      hash_prefix="",
      nonprivate_points=data.datapoints,
      coreset_param=coreset_param,
      sim_hash=sim_hash,
      private_count=private_count)
class LshTree():
  """Tree that splits the data into groups by prefixes of their LSH values.

  Attributes:
    tree: Maps level indices to each level. Level 0 contains only the root.
    leaves: Leaf nodes of the tree, ordered by level and then by their position
      within the level.
  """
  tree: typing.Dict[LevelIndex, LshTreeLevel]
  leaves: LshTreeLeaves

  def __init__(self, root: LshTreeNode):
    """Initializes an LshTree with the given root.

    Levels are added one at a time until either max_depth is reached or
    branching produces no further nodes. Afterwards the leaves are collected.

    Args:
      root: Root to use for the LshTree. Required to have private count >= 1.

    Raises:
      ValueError: If the private count of the root is less than 1.
    """
    if root.private_count < 1:
      raise ValueError("Private count of the root must be at least 1.")

    max_depth = root.coreset_param.tree_param.max_depth
    logging.debug("Starting tree construction with max_levels %s", max_depth)

    level_idx: LevelIndex = 0
    self.tree = {level_idx: [root]}
    while level_idx < max_depth:
      # Branch all the nodes that should be branched.
      branching_nodes: NodesToBranch = LshTree.filter_branching_nodes(
          self.tree[level_idx])
      next_level = LshTree.get_next_level(branching_nodes)
      if not next_level:
        # Nothing survived branching, so the tree cannot grow any deeper.
        break
      level_idx += 1
      self.tree[level_idx] = next_level
    logging.debug("Tree generated (level -> nodes): %s", self.tree)

    logging.debug("Starting to collect the leaves of the tree.")
    self.leaves = self._collect_leaves()
    logging.debug("Found %s leaves: %s", len(self.leaves), self.leaves)

  def _collect_leaves(self) -> LshTreeLeaves:
    """Returns all nodes of self.tree that have no child, in level order.

    Uses the same criterion as is_leaf, but indexes the parent prefixes of each
    level once instead of rescanning the level below for every single node.
    """
    # For each level, the prefixes that its nodes' parents must have: every
    # level adds exactly one character to the hash prefix.
    parent_prefixes_by_level = {
        level_idx: {child.hash_prefix[:-1] for child in level}
        for level_idx, level in self.tree.items()
    }
    leaves = []
    for level in self.tree.values():
      for node in level:
        level_below = len(node.hash_prefix) + 1
        # A missing level below means the node is on the last level.
        parent_prefixes = parent_prefixes_by_level.get(level_below, ())
        if node.hash_prefix not in parent_prefixes:
          leaves.append(node)
    return leaves

  def is_leaf(self, node: LshTreeNode) -> bool:
    """Returns whether the node is a leaf.

    A node is a leaf when no node on the level below has a hash prefix that
    extends the node's hash prefix by one character. Nodes on the last level
    are therefore always leaves.

    Args:
      node: LshTreeNode in this tree to check whether it has any children.
    """
    level_below = len(node.hash_prefix) + 1
    # Defaulting to an empty level covers nodes on the last level, for which
    # there is no level below to search.
    return not any(
        node.hash_prefix == maybe_child.hash_prefix[:-1]
        for maybe_child in self.tree.get(level_below, ()))

  @staticmethod
  def filter_branching_nodes(tree_level: LshTreeLevel) -> NodesToBranch:
    """Returns the nodes in tree_level that have enough points to branch.

    A node has enough points when its private count is at least its
    tree_param.min_num_points_in_branching_node. The order of the nodes is
    preserved.

    Args:
      tree_level: A level of the tree.
    """
    return [
        node for node in tree_level
        if node.private_count >=
        node.coreset_param.tree_param.min_num_points_in_branching_node
    ]

  @staticmethod
  def get_next_level(nodes_to_branch: NodesToBranch) -> LshTreeLevel:
    """Returns the next level of the tree based on nodes_to_branch.

    The children of all given nodes are gathered in order, keeping only those
    whose private count is at least their tree_param.min_num_points_in_node.

    Args:
      nodes_to_branch: Nodes to branch for getting the next level in the tree.
    """
    next_level = []
    for parent in nodes_to_branch:
      next_level.extend(
          child for child in parent.children()
          if child.private_count >=
          child.coreset_param.tree_param.min_num_points_in_node)
    return next_level