blob: a176e9bbed7899b40c4cc6a3962d93b5f8d7a5e3 [file] [log] [blame] [edit]
# Copyright 2021 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LSH tree for hierarchically grouping nearby points."""
import dataclasses
import functools
from typing import Dict, List, Optional
from absl import logging
import numpy as np
from clustering import central_privacy_utils
from clustering import clustering_params
from clustering import coreset_params
from clustering import lsh
# Type alias for an LSH hash prefix: a string of hash characters, where each
# tree level appends one character (the root's prefix is the empty string).
HashPrefix = str
@dataclasses.dataclass
class LshTreeNode():
    """One node of an LSH tree: a hash prefix plus the points that share it.

    Attributes:
      hash_prefix: The hash prefix this node stands for.
      nonprivate_points: The raw points whose hashes start with hash_prefix.
      coreset_param: Clustering params this node was built with.
      sim_hash: The LSH that produced the hashes.
      private_count: Private count of nonprivate_points. When it is not passed
        in, it is computed as part of construction.
    """
    hash_prefix: HashPrefix
    nonprivate_points: np.ndarray
    coreset_param: coreset_params.CoresetParam
    sim_hash: lsh.SimHash
    private_count: Optional[int] = None

    def __post_init__(self):
        # get_private_count is a no-op when a count was supplied, so calling it
        # unconditionally only fills in a missing count.
        self.get_private_count()

    @functools.cached_property
    def private_average(self) -> np.ndarray:
        """Private average of the node's points, computed once and then cached.

        Callers must ensure self.private_count >= 1 before accessing this.
        """
        average_param = self.coreset_param.pcalc.average_privacy_param
        return central_privacy_utils.get_private_average(
            self.nonprivate_points, self.private_count, average_param,
            self.sim_hash.dim)

    def get_private_count(self) -> int:
        """Private count of the node's points; computed on first call and stored."""
        if self.private_count is None:
            count_param = self.coreset_param.pcalc.count_privacy_param
            self.private_count = central_privacy_utils.get_private_count(
                len(self.nonprivate_points), count_param)
        return self.private_count

    def children(self) -> List["LshTreeNode"]:
        """Builds one child node per next hash character of this node's points.

        Each child's prefix is self.hash_prefix extended by one hash character,
        as grouped by sim_hash.group_by_next_hash. Children are produced no
        matter what self.coreset_param.tree_param says.
        """
        points_by_next_char = self.sim_hash.group_by_next_hash(
            self.nonprivate_points, hash_prefix=self.hash_prefix)
        kids = []
        for next_char, child_points in points_by_next_char.items():
            kids.append(
                LshTreeNode(self.hash_prefix + next_char, child_points,
                            self.coreset_param, self.sim_hash))
        return kids

    def __repr__(self) -> str:
        """Renders the node as private_count(hash_prefix)."""
        return f"{self.private_count}({self.hash_prefix})"
# Index of a level in the tree; the root sits at level 0.
LevelIndex = int
# All nodes that live on a single level of the tree.
LshTreeLevel = List[LshTreeNode]
# The part of a level made up only of the nodes that will be branched.
NodesToBranch = LshTreeLevel
# Every leaf of the LSH tree, gathered across all levels.
LshTreeLeaves = List[LshTreeNode]
def root_node(data: clustering_params.Data,
              coreset_param: coreset_params.CoresetParam,
              private_count: Optional[int] = None) -> LshTreeNode:
    """Returns the root node for an LSH prefix tree.

    The root has the empty hash prefix and holds every datapoint in data.

    Args:
      data: Data to use for generating the tree.
      coreset_param: Clustering parameters to use for generating the tree.
      private_count: Private count for the number of datapoints. If None, the
        private count will be computed when the node is constructed.

    Returns:
      An LshTreeNode with hash prefix "" over data.datapoints. Its SimHash is
      built from data.dim and coreset_param.tree_param.max_depth.
    """
    sim_hash = lsh.SimHash(data.dim, coreset_param.tree_param.max_depth)
    return LshTreeNode(
        "", data.datapoints, coreset_param, sim_hash,
        private_count=private_count)
class LshTree():
    """Tree in which the data is split into groups based on prefixes from the LSH values.

    Level i holds the nodes whose hash prefixes have i characters, and the root
    (empty prefix) is level 0. A node on a level is branched only if its private
    count reaches tree_param.min_num_points_in_branching_node. A resulting child
    is kept only if its private count reaches tree_param.min_num_points_in_node.
    Construction stops at tree_param.max_depth or when a level yields no kept
    children.

    Attributes:
      tree: Maps level indices to each level.
      leaves: Leaf nodes of the tree (nodes with no kept children), listed
        level by level in the order they appear in tree.
    """
    tree: Dict[LevelIndex, LshTreeLevel]
    leaves: LshTreeLeaves

    def __init__(self, root: LshTreeNode):
        """Initializes an LshTree with the given root.

        Args:
          root: Root to use for the LshTree. Required to have private count >= 1.

        Raises:
          ValueError: If the private count of the root is less than 1.
        """
        if root.private_count < 1:
            raise ValueError("Private count of the root must be at least 1.")
        max_depth = root.coreset_param.tree_param.max_depth
        logging.debug("Starting tree construction with max_levels %s",
                      max_depth)
        level_idx: LevelIndex = 0
        self.tree: Dict[LevelIndex, LshTreeLevel] = {level_idx: [root]}
        while level_idx < max_depth:
            # Branch all the nodes that should be branched.
            branching_nodes: NodesToBranch = LshTree.filter_branching_nodes(
                self.tree[level_idx])
            next_level = LshTree.get_next_level(branching_nodes)
            if not next_level:
                # No kept children: the tree cannot grow any deeper.
                break
            level_idx += 1
            self.tree[level_idx] = next_level
        logging.debug("Tree generated (level -> nodes): %s", self.tree)

        logging.debug("Starting to collect the leaves of the tree.")
        self.leaves = self._collect_leaves()
        logging.debug("Found %s leaves: %s", len(self.leaves), self.leaves)

    def _collect_leaves(self) -> LshTreeLeaves:
        """Returns the nodes that satisfy is_leaf, in level order.

        This gives the same result as filtering every level with is_leaf. The
        difference is that it builds each level's set of parent prefixes once,
        so the pass is linear in the number of nodes rather than comparing every
        node against every node on the level below.
        """
        # Each level adds one character to the hash prefix, so dropping the last
        # character of a node's prefix gives its parent's prefix.
        parent_prefixes_by_level = {
            idx: {child.hash_prefix[:-1] for child in level}
            for idx, level in self.tree.items()
        }
        no_children = frozenset()
        leaves: LshTreeLeaves = []
        for level in self.tree.values():
            for node in level:
                # A missing level below (i.e. the deepest level) means no children.
                prefixes_with_children = parent_prefixes_by_level.get(
                    len(node.hash_prefix) + 1, no_children)
                if node.hash_prefix not in prefixes_with_children:
                    leaves.append(node)
        return leaves

    def is_leaf(self, node: LshTreeNode) -> bool:
        """Returns whether the node is a leaf, i.e. has no child in this tree.

        Args:
          node: LshTreeNode in this tree to check whether it has any children.
        """
        level_below = len(node.hash_prefix) + 1
        # If node is in the last level, it is a leaf.
        if level_below > max(self.tree):
            return True
        # Each level adds one character to the hash prefix.
        return not any(node.hash_prefix == maybe_child.hash_prefix[:-1]
                       for maybe_child in self.tree[level_below])

    @staticmethod
    def filter_branching_nodes(tree_level: LshTreeLevel) -> NodesToBranch:
        """Returns the nodes in tree_level that have enough points to branch.

        Args:
          tree_level: A level of the tree.
        """
        return [
            node for node in tree_level
            if node.private_count >=
            node.coreset_param.tree_param.min_num_points_in_branching_node
        ]

    @staticmethod
    def get_next_level(nodes_to_branch: NodesToBranch) -> LshTreeLevel:
        """Returns the next level of the tree based on nodes_to_branch.

        Every child of every node in nodes_to_branch is created, in order. Only
        children whose private count reaches tree_param.min_num_points_in_node
        are kept.

        Args:
          nodes_to_branch: Nodes to branch for getting the next level in the tree.
        """
        # Build all children first, keeping the same node-by-node creation order.
        all_children = [
            child for node in nodes_to_branch for child in node.children()
        ]
        return [
            child for child in all_children
            if child.private_count >=
            child.coreset_param.tree_param.min_num_points_in_node
        ]