| # Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Clustering Operations.""" |
| |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import random_seed as random_seed_ops |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import check_ops |
| from tensorflow.python.ops import cond |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import gen_clustering_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import nn_impl |
| from tensorflow.python.ops import random_ops |
| from tensorflow.python.ops import state_ops |
| from tensorflow.python.ops import variable_v1 |
| from tensorflow.python.ops import while_loop |
| from tensorflow.python.ops.embedding_ops import embedding_lookup |
| # go/tf-wildcard-import |
| # pylint: disable=wildcard-import |
| from tensorflow.python.ops.gen_clustering_ops import * |
| # pylint: enable=wildcard-import |
| |
# Euclidean distance between vectors U and V is defined as \\(||U - V||_F\\)
# which is the square root of the sum of the absolute squares of the elements
# difference.
SQUARED_EUCLIDEAN_DISTANCE = 'squared_euclidean'
# Cosine distance between vectors U and V is defined as
# \\(1 - (U \dot V) / (||U||_F ||V||_F)\\)
COSINE_DISTANCE = 'cosine'

# Supported cluster-center initialization algorithms; see KMeans.__init__ for
# how each one chooses centers from the input batches.
RANDOM_INIT = 'random'
KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus'
KMC2_INIT = 'kmc2'

# The name of the variable holding the cluster centers. Used by the Estimator.
CLUSTERS_VAR_NAME = 'clusters'
| |
| |
class KMeans:
  """Creates the graph for k-means clustering."""

  def __init__(self,
               inputs,
               num_clusters,
               initial_clusters=RANDOM_INIT,
               distance_metric=SQUARED_EUCLIDEAN_DISTANCE,
               use_mini_batch=False,
               mini_batch_steps_per_iteration=1,
               random_seed=0,
               kmeans_plus_plus_num_retries=2,
               kmc2_chain_length=200):
    """Creates an object for generating KMeans clustering graph.

    This class implements the following variants of K-means algorithm:

    If use_mini_batch is False, it runs standard full batch K-means. Each step
    runs a single iteration of K-Means. This step can be run sharded across
    multiple workers by passing a list of sharded inputs to this class. Note
    however that a single step needs to process the full input at once.

    If use_mini_batch is True, it runs a generalization of the mini-batch
    K-means algorithm. It runs multiple iterations, where each iteration is
    composed of mini_batch_steps_per_iteration steps. Two copies of cluster
    centers are maintained: one that is updated at the end of each iteration,
    and one that is updated every step. The first copy is used to compute
    cluster allocations for each step, and for inference, while the second copy
    is the one updated each step using the mini-batch update rule. After each
    iteration is complete, this second copy is copied back to the first copy.

    Note that for use_mini_batch=True, when mini_batch_steps_per_iteration=1,
    the algorithm reduces to the standard mini-batch algorithm. Also by setting
    mini_batch_steps_per_iteration = num_inputs / batch_size, the algorithm
    becomes an asynchronous version of the full-batch algorithm. Note however
    that there is no guarantee by this implementation that each input is seen
    exactly once per iteration. Also, different updates are applied
    asynchronously without locking. So this asynchronous version may not behave
    exactly like a full-batch version.

    Args:
      inputs: An input tensor or list of input tensors. It is assumed that the
        data points have been previously randomly permuted.
      num_clusters: An integer tensor specifying the number of clusters. This
        argument is ignored if initial_clusters is a tensor or numpy array.
      initial_clusters: Specifies the clusters used during initialization. One
        of the following: - a tensor or numpy array with the initial cluster
        centers. - a function f(inputs, k) that returns up to k centers from
        `inputs`.
        - "random": Choose centers randomly from `inputs`.
        - "kmeans_plus_plus": Use kmeans++ to choose centers from `inputs`.
        - "kmc2": Use the fast k-MC2 algorithm to choose centers from `inputs`.
        In the last three cases, one batch of `inputs` may not yield
        `num_clusters` centers, in which case initialization will require
        multiple batches until enough centers are chosen. In the case of
        "random" or "kmeans_plus_plus", if the input size is <= `num_clusters`
        then the entire batch is chosen to be cluster centers.
      distance_metric: Distance metric used for clustering. Supported options:
        "squared_euclidean", "cosine".
      use_mini_batch: If true, use the mini-batch k-means algorithm. Else assume
        full batch.
      mini_batch_steps_per_iteration: Number of steps after which the updated
        cluster centers are synced back to a master copy.
      random_seed: Seed for PRNG used to initialize seeds.
      kmeans_plus_plus_num_retries: For each point that is sampled during
        kmeans++ initialization, this parameter specifies the number of
        additional points to draw from the current distribution before selecting
        the best. If a negative value is specified, a heuristic is used to
        sample O(log(num_to_sample)) additional points.
      kmc2_chain_length: Determines how many candidate points are used by the
        k-MC2 algorithm to produce one new cluster centers. If a (mini-)batch
        contains less points, one new cluster center is generated from the
        (mini-)batch.

    Raises:
      ValueError: An invalid argument was passed to initial_clusters or
        distance_metric.
    """
    initialization_algorithms = [RANDOM_INIT, KMEANS_PLUS_PLUS_INIT, KMC2_INIT]
    if isinstance(initial_clusters,
                  str) and initial_clusters not in initialization_algorithms:
      raise ValueError(
          f'Unsupported initialization algorithm `{initial_clusters}`, '
          f'must be one of `{initialization_algorithms}`.')

    distance_metrics = [SQUARED_EUCLIDEAN_DISTANCE, COSINE_DISTANCE]
    if distance_metric not in distance_metrics:
      raise ValueError(f'Unsupported distance metric `{distance_metric}`, '
                       f'must be one of `{distance_metrics}`.')
    # Always operate on a list of input shards internally.
    self._inputs = inputs if isinstance(inputs, list) else [inputs]
    self._num_clusters = num_clusters
    self._initial_clusters = initial_clusters
    self._distance_metric = distance_metric
    self._use_mini_batch = use_mini_batch
    self._mini_batch_steps_per_iteration = int(mini_batch_steps_per_iteration)
    self._seed = random_seed_ops.get_seed(random_seed)[0]
    self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
    self._kmc2_chain_length = kmc2_chain_length

  @classmethod
  def _distance_graph(cls, inputs, clusters, distance_metric):
    """Computes distance between each input and each cluster center.

    Args:
      inputs: list of input Tensors.
      clusters: cluster Tensor.
      distance_metric: distance metric used for clustering

    Returns:
      list of Tensors, where each element corresponds to each element in inputs.
      The value is the distance of each row to all the cluster centers.
      Currently only Euclidean distance and cosine distance are supported.
    """
    assert isinstance(inputs, list)
    if distance_metric == SQUARED_EUCLIDEAN_DISTANCE:
      return cls._compute_euclidean_distance(inputs, clusters)
    elif distance_metric == COSINE_DISTANCE:
      return cls._compute_cosine_distance(
          inputs, clusters, inputs_normalized=True)
    else:
      # Unreachable: __init__ validates distance_metric.
      assert False, str(distance_metric)

  @classmethod
  def _compute_euclidean_distance(cls, inputs, clusters):
    """Computes Euclidean distance between each input and each cluster center.

    Args:
      inputs: list of input Tensors.
      clusters: cluster Tensor.

    Returns:
      list of Tensors, where each element corresponds to each element in inputs.
      The value is the distance of each row to all the cluster centers.
    """
    output = []
    for inp in inputs:
      with ops.colocate_with(inp, ignore_existing=True):
        # Computes Euclidean distance. Note the first and third terms are
        # broadcast additions.
        squared_distance = (
            math_ops.reduce_sum(math_ops.square(inp), 1, keepdims=True) -
            2 * math_ops.matmul(inp, clusters, transpose_b=True) +
            array_ops.transpose(
                math_ops.reduce_sum(
                    math_ops.square(clusters), 1, keepdims=True)))
        output.append(squared_distance)

    return output

  @classmethod
  def _compute_cosine_distance(cls, inputs, clusters, inputs_normalized=True):
    """Computes cosine distance between each input and each cluster center.

    Args:
      inputs: list of input Tensor.
      clusters: cluster Tensor
      inputs_normalized: if True, it assumes that inp and clusters are
        normalized and computes the dot product which is equivalent to the
        cosine distance. Else it L2 normalizes the inputs first.

    Returns:
      list of Tensors, where each element corresponds to each element in inp.
      The value is the distance of each row to all the cluster centers.
    """
    output = []
    if not inputs_normalized:
      with ops.colocate_with(clusters, ignore_existing=True):
        clusters = nn_impl.l2_normalize(clusters, axis=1)
    for inp in inputs:
      with ops.colocate_with(inp, ignore_existing=True):
        if not inputs_normalized:
          inp = nn_impl.l2_normalize(inp, axis=1)
        output.append(1 - math_ops.matmul(inp, clusters, transpose_b=True))
    return output

  def _infer_graph(self, inputs, clusters):
    """Maps input to closest cluster and the score.

    Args:
      inputs: list of input Tensors.
      clusters: Tensor of cluster centers.

    Returns:
      List of tuple, where each value in tuple corresponds to a value in inp.
      The tuple has following three elements:
      all_scores: distance of each input to each cluster center.
      score: distance of each input to closest cluster center.
      cluster_idx: index of cluster center closest to the corresponding input.
    """
    assert isinstance(inputs, list)
    # Pairwise distances are used only by transform(). In all other cases, this
    # sub-graph is not evaluated.
    scores = self._distance_graph(inputs, clusters, self._distance_metric)
    output = []
    if (self._distance_metric == COSINE_DISTANCE and
        not self._clusters_l2_normalized()):
      # The cosine distance between normalized vectors x and y is the same as
      # 2 * squared_euclidean_distance. We are using this fact and reusing the
      # nearest_neighbors op.
      # TODO(ands): Support COSINE distance in nearest_neighbors and remove
      # this.
      with ops.colocate_with(clusters, ignore_existing=True):
        clusters = nn_impl.l2_normalize(clusters, axis=1)
    for inp, score in zip(inputs, scores):
      with ops.colocate_with(inp, ignore_existing=True):
        (indices,
         distances) = gen_clustering_ops.nearest_neighbors(inp, clusters, 1)
        if self._distance_metric == COSINE_DISTANCE:
          # Scale the squared Euclidean distance by 0.5 to recover the cosine
          # distance between normalized vectors (see comment above).
          distances *= 0.5
        output.append(
            (score, array_ops.squeeze(distances,
                                      [-1]), array_ops.squeeze(indices, [-1])))
    return zip(*output)

  def _clusters_l2_normalized(self):
    """Returns True if clusters centers are kept normalized."""
    return (self._distance_metric == COSINE_DISTANCE and
            (not self._use_mini_batch or
             self._mini_batch_steps_per_iteration > 1))

  def _create_variables(self, num_clusters):
    """Creates variables.

    Args:
      num_clusters: an integer Tensor providing the number of clusters.

    Returns:
      Tuple with following elements:
      - cluster_centers: a Tensor for storing cluster centers
      - cluster_centers_initialized: bool Variable indicating whether clusters
            are initialized.
      - cluster_counts: a Tensor for storing counts of points assigned to this
            cluster. This is used by mini-batch training.
      - cluster_centers_updated: Tensor representing copy of cluster centers
            that are updated every step.
      - update_in_steps: numbers of steps left before we sync
            cluster_centers_updated back to cluster_centers.
    """
    # The centers variable starts empty; its actual shape is only known once
    # initialization has chosen centers, hence validate_shape=False.
    init_value = array_ops.placeholder_with_default([], shape=None)
    cluster_centers = variable_v1.VariableV1(
        init_value, name=CLUSTERS_VAR_NAME, validate_shape=False)
    cluster_centers_initialized = variable_v1.VariableV1(
        False, dtype=dtypes.bool, name='initialized')

    if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
      # Copy of cluster centers actively updated each step according to
      # mini-batch update rule.
      cluster_centers_updated = variable_v1.VariableV1(
          init_value, name='clusters_updated', validate_shape=False)
      # How many steps till we copy the updated clusters to cluster_centers.
      update_in_steps = variable_v1.VariableV1(
          self._mini_batch_steps_per_iteration,
          dtype=dtypes.int64,
          name='update_in_steps')
      # Count of points assigned to cluster_centers_updated.
      cluster_counts = variable_v1.VariableV1(
          array_ops.zeros([num_clusters], dtype=dtypes.int64))
    else:
      cluster_centers_updated = cluster_centers
      update_in_steps = None
      cluster_counts = (
          variable_v1.VariableV1(
              array_ops.ones([num_clusters], dtype=dtypes.int64))
          if self._use_mini_batch else None)
    return (cluster_centers, cluster_centers_initialized, cluster_counts,
            cluster_centers_updated, update_in_steps)

  @classmethod
  def _l2_normalize_data(cls, inputs):
    """Normalized the input data."""
    output = []
    for inp in inputs:
      with ops.colocate_with(inp, ignore_existing=True):
        output.append(nn_impl.l2_normalize(inp, axis=1))
    return output

  def training_graph(self):
    """Generate a training graph for kmeans algorithm.

    This returns, among other things, an op that chooses initial centers
    (init_op), a boolean variable that is set to True when the initial centers
    are chosen (cluster_centers_initialized), and an op to perform either an
    entire Lloyd iteration or a mini-batch of a Lloyd iteration (training_op).
    The caller should use these components as follows. A single worker should
    execute init_op multiple times until cluster_centers_initialized becomes
    True. Then multiple workers may execute training_op any number of times.

    Returns:
      A tuple consisting of:
      all_scores: A matrix (or list of matrices) of dimensions (num_input,
        num_clusters) where the value is the distance of an input vector and a
        cluster center.
      cluster_idx: A vector (or list of vectors). Each element in the vector
        corresponds to an input row in 'inp' and specifies the cluster id
        corresponding to the input.
      scores: Similar to cluster_idx but specifies the distance to the
        assigned cluster instead.
      cluster_centers_initialized: scalar indicating whether clusters have been
        initialized.
      init_op: an op to initialize the clusters.
      training_op: an op that runs an iteration of training.
    """
    # Implementation of kmeans.
    if (isinstance(self._initial_clusters, str) or
        callable(self._initial_clusters)):
      initial_clusters = self._initial_clusters
      num_clusters = ops.convert_to_tensor(self._num_clusters)
    else:
      # Initial centers were given explicitly; num_clusters is derived from
      # them and the num_clusters constructor argument is ignored.
      initial_clusters = ops.convert_to_tensor(self._initial_clusters)
      num_clusters = array_ops.shape(initial_clusters)[0]

    inputs = self._inputs
    (cluster_centers_var, cluster_centers_initialized, total_counts,
     cluster_centers_updated,
     update_in_steps) = self._create_variables(num_clusters)
    init_op = _InitializeClustersOpFactory(
        self._inputs, num_clusters, initial_clusters, self._distance_metric,
        self._seed, self._kmeans_plus_plus_num_retries, self._kmc2_chain_length,
        cluster_centers_var, cluster_centers_updated,
        cluster_centers_initialized).op()
    cluster_centers = cluster_centers_var

    if self._distance_metric == COSINE_DISTANCE:
      inputs = self._l2_normalize_data(inputs)
      if not self._clusters_l2_normalized():
        cluster_centers = nn_impl.l2_normalize(cluster_centers, axis=1)

    all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
    if self._use_mini_batch:
      sync_updates_op = self._mini_batch_sync_updates_op(
          update_in_steps, cluster_centers_var, cluster_centers_updated,
          total_counts)
      assert sync_updates_op is not None
      with ops.control_dependencies([sync_updates_op]):
        training_op = self._mini_batch_training_op(inputs, cluster_idx,
                                                   cluster_centers_updated,
                                                   total_counts)
    else:
      assert cluster_centers == cluster_centers_var
      training_op = self._full_batch_training_op(inputs, num_clusters,
                                                 cluster_idx,
                                                 cluster_centers_var)

    return (all_scores, cluster_idx, scores, cluster_centers_initialized,
            init_op, training_op)

  def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var,
                                  cluster_centers_updated, total_counts):
    if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
      assert update_in_steps is not None
      with ops.colocate_with(update_in_steps, ignore_existing=True):

        def _f():
          # Note that there is a race condition here, so we do a best effort
          # updates here. We reset update_in_steps first so that other workers
          # don't duplicate the updates. Also we update cluster_center_vars
          # before resetting total_counts to avoid large updates to
          # cluster_centers_updated based on partially updated
          # cluster_center_vars.
          with ops.control_dependencies([
              state_ops.assign(update_in_steps,
                               self._mini_batch_steps_per_iteration - 1)
          ]):
            with ops.colocate_with(
                cluster_centers_updated, ignore_existing=True):
              if self._distance_metric == COSINE_DISTANCE:
                cluster_centers = nn_impl.l2_normalize(
                    cluster_centers_updated, axis=1)
              else:
                cluster_centers = cluster_centers_updated
            with ops.colocate_with(cluster_centers_var, ignore_existing=True):
              with ops.control_dependencies(
                  [state_ops.assign(cluster_centers_var, cluster_centers)]):
                with ops.colocate_with(None, ignore_existing=True):
                  with ops.control_dependencies([
                      state_ops.assign(total_counts,
                                       array_ops.zeros_like(total_counts))
                  ]):
                    return array_ops.identity(update_in_steps)

      return cond.cond(
          update_in_steps <= 0, _f,
          lambda: state_ops.assign_sub(update_in_steps, 1))
    else:
      return control_flow_ops.no_op()

  def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
                              total_counts):
    """Creates an op for training for mini batch case.

    Args:
      inputs: list of input Tensors.
      cluster_idx_list: A vector (or list of vectors). Each element in the
        vector corresponds to an input row in 'inp' and specifies the cluster id
        corresponding to the input.
      cluster_centers: Tensor Ref of cluster centers.
      total_counts: Tensor Ref of cluster counts.

    Returns:
      An op for doing an update of mini-batch k-means.
    """
    update_ops = []
    for inp, cluster_idx in zip(inputs, cluster_idx_list):
      with ops.colocate_with(inp, ignore_existing=True):
        assert total_counts is not None
        cluster_idx = array_ops.reshape(cluster_idx, [-1])
        # Dedupe the unique ids of cluster_centers being updated so that updates
        # can be locally aggregated.
        unique_ids, unique_idx = array_ops.unique(cluster_idx)
        num_unique_cluster_idx = array_ops.size(unique_ids)
        # Fetch the old values of counts and cluster_centers.
        with ops.colocate_with(total_counts, ignore_existing=True):
          old_counts = array_ops.gather(total_counts, unique_ids)
        # TODO(agarwal): This colocation seems to run into problems. Fix it.
        with ops.colocate_with(cluster_centers, ignore_existing=True):
          old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
        # Locally aggregate the increment to counts.
        count_updates = math_ops.unsorted_segment_sum(
            array_ops.ones_like(unique_idx, dtype=total_counts.dtype),
            unique_idx, num_unique_cluster_idx)
        # Locally compute the sum of inputs mapped to each id.
        # For a cluster with old cluster value x, old count n, and with data
        # d_1,...d_k newly assigned to it, we recompute the new value as
        # \\(x += (sum_i(d_i) - k * x) / (n + k)\\).
        # Compute \\(sum_i(d_i)\\), see comment above.
        cluster_center_updates = math_ops.unsorted_segment_sum(
            inp, unique_idx, num_unique_cluster_idx)
        # Shape to enable broadcasting count_updates and learning_rate to inp.
        # It extends the shape with 1's to match the rank of inp.
        broadcast_shape = array_ops.concat([
            array_ops.reshape(num_unique_cluster_idx, [1]),
            array_ops.ones(
                array_ops.reshape(array_ops.rank(inp) - 1, [1]),
                dtype=dtypes.int32)
        ], 0)
        # Subtract k * x, see comment above.
        cluster_center_updates -= math_ops.cast(
            array_ops.reshape(count_updates, broadcast_shape),
            inp.dtype) * old_cluster_centers
        learning_rate = math_ops.reciprocal(
            math_ops.cast(old_counts + count_updates, inp.dtype))
        learning_rate = array_ops.reshape(learning_rate, broadcast_shape)
        # scale by 1 / (n + k), see comment above.
        cluster_center_updates *= learning_rate
        # Apply the updates.
        update_counts = state_ops.scatter_add(total_counts, unique_ids,
                                              count_updates)
        update_cluster_centers = state_ops.scatter_add(cluster_centers,
                                                       unique_ids,
                                                       cluster_center_updates)
        update_ops.extend([update_counts, update_cluster_centers])
    return control_flow_ops.group(*update_ops)

  def _full_batch_training_op(self, inputs, num_clusters, cluster_idx_list,
                              cluster_centers):
    """Creates an op for training for full batch case.

    Args:
      inputs: list of input Tensors.
      num_clusters: an integer Tensor providing the number of clusters.
      cluster_idx_list: A vector (or list of vectors). Each element in the
        vector corresponds to an input row in 'inp' and specifies the cluster id
        corresponding to the input.
      cluster_centers: Tensor Ref of cluster centers.

    Returns:
      An op for doing an update of mini-batch k-means.
    """
    cluster_sums = []
    cluster_counts = []
    # Epsilon avoids division by zero for clusters with no assigned points.
    epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype)
    for inp, cluster_idx in zip(inputs, cluster_idx_list):
      with ops.colocate_with(inp, ignore_existing=True):
        cluster_sums.append(
            math_ops.unsorted_segment_sum(inp, cluster_idx, num_clusters))
        cluster_counts.append(
            math_ops.unsorted_segment_sum(
                array_ops.reshape(
                    array_ops.ones(
                        array_ops.reshape(array_ops.shape(inp)[0], [-1])),
                    [-1, 1]), cluster_idx, num_clusters))
    with ops.colocate_with(cluster_centers, ignore_existing=True):
      new_clusters_centers = math_ops.add_n(cluster_sums) / (
          math_ops.cast(math_ops.add_n(cluster_counts), cluster_sums[0].dtype) +
          epsilon)
      if self._clusters_l2_normalized():
        new_clusters_centers = nn_impl.l2_normalize(new_clusters_centers,
                                                    axis=1)
    return state_ops.assign(cluster_centers, new_clusters_centers)
| |
| |
class _InitializeClustersOpFactory:
  """Internal class to create the op to initialize the clusters.

  The op performs this algorithm (see constructor args):

  num_remaining = num_clusters - length(cluster_centers)
  if num_remaining == 0:
    assert that cluster_centers_initialized is true
  else:
    assert that num_remaining > 0
    new_centers = choose up to num_remaining initial centers
    l2-normalize new_centers if using cosine distance
    all_centers = concat(cluster_centers, new_centers)
    cluster_centers := all_centers
    if there is a cluster_centers_updated variable:
      cluster_centers_updated := cluster_centers
    num_now_remaining = num_clusters - length(cluster_centers)
    if num_now_remaining == 0:
      cluster_centers_initialized := true
  """

  # TODO(ccolby): Refactor this class so that kmc2 isn't so much a special case.

  def __init__(self, inputs, num_clusters, initial_clusters, distance_metric,
               random_seed, kmeans_plus_plus_num_retries, kmc2_chain_length,
               cluster_centers, cluster_centers_updated,
               cluster_centers_initialized):
    """Creates an op factory.

    Args:
      inputs: See KMeans constructor.
      num_clusters: An integer Tensor providing the number of clusters.
      initial_clusters: See KMeans constructor.
      distance_metric: See KMeans constructor.
      random_seed: See KMeans constructor.
      kmeans_plus_plus_num_retries: See KMeans constructor.
      kmc2_chain_length: See KMeans constructor.
      cluster_centers: The TF variable holding the initial centers. It may
        already contain some centers when the op is executed.
      cluster_centers_updated: A second TF variable to hold a copy of the
        initial centers, used for full-batch mode. In mini-batch mode,
        cluster_centers_updated is the same variable as cluster_centers.
      cluster_centers_initialized: A boolean TF variable that will be set to
        true when all the initial centers have been chosen.
    """
    # All of these instance variables are constants.
    self._inputs = inputs
    self._num_clusters = num_clusters
    self._initial_clusters = initial_clusters
    self._distance_metric = distance_metric
    self._seed = random_seed
    self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
    self._kmc2_chain_length = kmc2_chain_length
    self._cluster_centers = cluster_centers
    self._cluster_centers_updated = cluster_centers_updated
    self._cluster_centers_initialized = cluster_centers_initialized

    # Number of centers already chosen, and how many are still needed.
    self._num_selected = array_ops.shape(self._cluster_centers)[0]
    self._num_remaining = self._num_clusters - self._num_selected
    # Total number of data points across all input shards.
    self._num_data = math_ops.add_n(
        [array_ops.shape(i)[0] for i in self._inputs])

  def _random(self):
    """Returns num_remaining points sampled uniformly at random from inputs."""
    indices = random_ops.random_uniform(
        array_ops.reshape(self._num_remaining, [-1]),
        minval=0,
        maxval=math_ops.cast(self._num_data, dtypes.int64),
        seed=self._seed,
        dtype=dtypes.int64)
    return embedding_lookup(self._inputs, indices, partition_strategy='div')

  def _kmeans_plus_plus(self):
    # Points from only the first shard are used for initializing centers.
    # TODO(ands): Use all points.
    inp = self._inputs[0]
    if self._distance_metric == COSINE_DISTANCE:
      inp = nn_impl.l2_normalize(inp, axis=1)
    return gen_clustering_ops.kmeans_plus_plus_initialization(
        inp, math_ops.cast(self._num_remaining, dtypes.int64), self._seed,
        self._kmeans_plus_plus_num_retries)

  def _kmc2_multiple_centers(self):
    """Adds new initial cluster centers using the k-MC2 algorithm.

    In each call to the op, the provided batch is split into subsets based on
    the specified `kmc2_chain_length`. On each subset, a single Markov chain of
    the k-MC2 algorithm is used to add *one* new center cluster center. If there
    are less than `kmc2_chain_length` points in the subset, a single center is
    added using one Markov chain on the full input. It is assumed that the
    provided batch has previously been randomly permuted. Otherwise, k-MC2 may
    return suboptimal centers.

    Returns:
      An op that adds new cluster centers.
    """
    # The op only operates on the first shard of data.
    first_shard = self._inputs[0]
    # Number of points in the input that can be used.
    batch_size = array_ops.shape(first_shard)[0]
    # Maximum number of subsets such that the size of each subset is at least
    # `kmc2_chain_length`. Final subsets may be larger.
    max_to_sample = math_ops.cast(
        batch_size / self._kmc2_chain_length, dtype=dtypes.int32)
    # We sample at least one new center and at most all remaining centers.
    num_to_sample = math_ops.maximum(
        math_ops.minimum(self._num_remaining, max_to_sample), 1)

    def _cond(i, _):
      """Stopping condition for the while loop."""
      return math_ops.less(i, num_to_sample)

    def _body(i, _):
      """Body that adds a single new center based on a subset."""

      def _sample_random():
        """Returns a random point as a cluster center."""
        # By assumption the batch is reshuffled and _sample_random is always
        # called for i=0. Hence, we simply return the first point.
        new_center = array_ops.reshape(first_shard[0], [1, -1])
        if self._distance_metric == COSINE_DISTANCE:
          new_center = nn_impl.l2_normalize(new_center, axis=1)
        return new_center

      def _sample_kmc2_chain():
        """Returns previous centers as well as a new center sampled using k-MC2."""
        # Extract the subset from the underlying batch.
        start = i * self._kmc2_chain_length
        end = start + self._kmc2_chain_length
        subset = first_shard[start:end]
        # Compute the distances from points in the subset to previous centers.
        _, distances = gen_clustering_ops.nearest_neighbors(
            subset, self._cluster_centers, 1)
        # Sample index of new center using k-MC2 Markov chain.
        new_center_index = gen_clustering_ops.kmc2_chain_initialization(
            array_ops.squeeze(distances), self._seed)
        # Extract actual new center.
        newly_sampled_center = array_ops.reshape(subset[new_center_index],
                                                 [1, -1])
        # Return concatenation with previously sampled centers.
        if self._distance_metric == COSINE_DISTANCE:
          newly_sampled_center = nn_impl.l2_normalize(
              newly_sampled_center, axis=1)
        return array_ops.concat([self._cluster_centers, newly_sampled_center],
                                0)

      # Obtain a random point if there are no previously sampled centers.
      # Otherwise, construct a k-MC2 Markov chain.
      new_centers = cond.cond(
          math_ops.equal(self._num_selected, 0), _sample_random,
          _sample_kmc2_chain)
      # Assign new cluster centers to underlying variable.
      assigned_centers = state_ops.assign(
          self._cluster_centers, new_centers, validate_shape=False)
      if self._cluster_centers_updated is not self._cluster_centers:
        assigned_centers = state_ops.assign(
            self._cluster_centers_updated,
            assigned_centers,
            validate_shape=False)
      return i + 1, self._num_clusters - array_ops.shape(assigned_centers)[0]

    # Add num_to_sample new data points.
    _, num_remaining = while_loop.while_loop(_cond, _body, [0, 0])
    return num_remaining

  def _greedy_batch_sampler(self, sampler):
    # If the input dataset size is smaller than the number of centers
    # remaining, choose the entire input dataset as centers. This can happen
    # with mini-batch. Otherwise, sample the batch according to the provided
    # sampler.
    return cond.cond(self._num_data <= self._num_remaining,
                     lambda: array_ops.concat(self._inputs, 0),
                     sampler)

  def _single_batch_sampler(self, sampler):
    # Enforce that there are at least as many data points as centers
    # remaining. This gives the provided sampler the chance to select all
    # remaining centers from a single batch.
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(self._num_data, self._num_remaining)]):
      return sampler()

  def _choose_initial_centers(self):
    if isinstance(self._initial_clusters, str):
      if self._initial_clusters == RANDOM_INIT:
        return self._greedy_batch_sampler(self._random)
      else:  # self._initial_clusters == KMEANS_PLUS_PLUS_INIT
        return self._single_batch_sampler(self._kmeans_plus_plus)
    elif callable(self._initial_clusters):
      return self._initial_clusters(self._inputs, self._num_remaining)
    else:
      # Explicitly-provided centers must account for all remaining slots.
      with ops.control_dependencies([
          check_ops.assert_equal(self._num_remaining,
                                 array_ops.shape(self._initial_clusters)[0])
      ]):
        return self._initial_clusters

  def _add_new_centers(self):
    """Adds some centers and returns the number of centers remaining."""
    new_centers = self._choose_initial_centers()
    if self._distance_metric == COSINE_DISTANCE:
      new_centers = nn_impl.l2_normalize(new_centers, axis=1)
    # If cluster_centers is empty, it doesn't have the right shape for concat.
    all_centers = cond.cond(
        math_ops.equal(self._num_selected, 0), lambda: new_centers,
        lambda: array_ops.concat([self._cluster_centers, new_centers], 0))
    # TODO(ccolby): De-dupe all_centers?
    a = state_ops.assign(
        self._cluster_centers, all_centers, validate_shape=False)
    if self._cluster_centers_updated is not self._cluster_centers:
      a = state_ops.assign(
          self._cluster_centers_updated, a, validate_shape=False)
    return self._num_clusters - array_ops.shape(a)[0]

  def _initialize(self):
    with ops.control_dependencies([
        check_ops.assert_positive(self._num_remaining),
    ]):
      if self._initial_clusters == KMC2_INIT:
        num_now_remaining = self._kmc2_multiple_centers()
      else:
        num_now_remaining = self._add_new_centers()
      return cond.cond(
          math_ops.equal(num_now_remaining, 0),
          lambda: state_ops.assign(self._cluster_centers_initialized, True),
          control_flow_ops.no_op)

  def op(self):
    """Returns the cluster initializer op."""
    return cond.cond(
        math_ops.equal(self._num_remaining, 0),
        lambda: check_ops.assert_equal(self._cluster_centers_initialized, True),
        self._initialize)