| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Ops for boosted_trees.""" |
| from tensorflow.python.framework import ops |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import gen_boosted_trees_ops |
| from tensorflow.python.ops import resources |
| |
| # Re-exporting ops used by other modules. |
| # pylint: disable=unused-import |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_aggregate_stats |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_bucketize |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_feature_split as calculate_best_feature_split |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_feature_split_v2 as calculate_best_feature_split_v2 |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_center_bias as center_bias |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_create_quantile_stream_resource as create_quantile_stream_resource |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_example_debug_outputs as example_debug_outputs |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_quantile_summaries as make_quantile_summaries |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_add_summaries as quantile_add_summaries |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_deserialize as quantile_resource_deserialize |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_flush as quantile_flush |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_get_bucket_boundaries as get_bucket_boundaries |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_handle_op as quantile_resource_handle_op |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_sparse_aggregate_stats |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_sparse_calculate_best_feature_split as sparse_calculate_best_feature_split |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble as update_ensemble |
| from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble_v2 as update_ensemble_v2 |
| from tensorflow.python.ops.gen_boosted_trees_ops import is_boosted_trees_quantile_stream_resource_initialized as is_quantile_resource_initialized |
| # pylint: enable=unused-import |
| |
| from tensorflow.python.training import saver |
| |
| |
class PruningMode:
  """Integer constants for the pruning modes, plus parsing from strings."""
  NO_PRUNING = 0
  PRE_PRUNING = 1
  POST_PRUNING = 2

  # Accepted string spellings and the constant each one stands for.
  _map = {'none': NO_PRUNING, 'pre': PRE_PRUNING, 'post': POST_PRUNING}

  @classmethod
  def from_str(cls, mode):
    """Returns the pruning constant registered under the string `mode`.

    Args:
      mode: one of the keys of `_map` ('none', 'pre' or 'post').

    Returns:
      The matching integer constant (NO_PRUNING, PRE_PRUNING or POST_PRUNING).

    Raises:
      ValueError: if `mode` is not a key of `_map`.
    """
    if mode not in cls._map:
      valid_modes = ', '.join(sorted(cls._map))
      raise ValueError(
          'pruning_mode mode must be one of: {}. Found: {}'.format(
              valid_modes, mode))
    return cls._map[mode]
| |
| |
class QuantileAccumulatorSaveable(saver.BaseSaverBuilder.SaveableObject):
  """SaveableObject that checkpoints a quantile stream resource.

  One save spec is emitted per stream; each spec holds the bucket boundaries
  tensor of that stream.
  """

  def __init__(self, resource_handle, create_op, num_streams, name):
    """Builds the per-stream save specs for the quantile resource.

    Args:
      resource_handle: handle to the quantile stream resource.
      create_op: op that `restore` adds as a control dependency before
        deserializing into the resource.
      num_streams: number of streams, i.e. number of bucket boundary tensors
        that are saved.
      name: prefix of the names the bucket boundaries are saved under.
    """
    self.resource_handle = resource_handle
    self._num_streams = num_streams
    self._create_op = create_op
    boundaries_per_stream = get_bucket_boundaries(self.resource_handle,
                                                  self._num_streams)
    # The tensors are saved whole, so an empty slice spec is passed.
    no_slice = ''
    save_specs = [
        saver.BaseSaverBuilder.SaveSpec(
            boundaries_per_stream[stream], no_slice,
            name + '_bucket_boundaries_' + str(stream))
        for stream in range(self._num_streams)
    ]
    super().__init__(self.resource_handle, save_specs, name)

  def restore(self, restored_tensors, unused_tensor_shapes):
    """Returns an op that loads `restored_tensors` back into the resource.

    Args:
      restored_tensors: the bucket boundary tensors read from a checkpoint.
      unused_tensor_shapes: ignored.

    Returns:
      The deserialize op, created under a control dependency on `create_op`.
    """
    with ops.control_dependencies([self._create_op]):
      return quantile_resource_deserialize(
          self.resource_handle, bucket_boundaries=restored_tensors)
| |
| |
class QuantileAccumulator:
  """Python wrapper around a quantile stream resource.

  Construction creates the resource handle and its init op, registers the
  resource, and adds a QuantileAccumulatorSaveable to the SAVEABLE_OBJECTS
  collection so that the bucket boundaries are serialized to and deserialized
  from checkpoints.
  """

  def __init__(self,
               epsilon,
               num_streams,
               num_quantiles,
               name=None,
               max_elements=None):
    """Creates the quantile stream resource and registers it.

    Args:
      epsilon: value forwarded to the resource creation op and to the
        quantile summary op.
      num_streams: number of streams held by the resource.
      num_quantiles: value forwarded to the flush op.
      name: optional name scope; 'QuantileAccumulator' is the default scope
        name.
      max_elements: accepted but not used.
    """
    del max_elements  # Unused.

    self._eps = epsilon
    self._num_streams = num_streams
    self._num_quantiles = num_quantiles

    with ops.name_scope(name, 'QuantileAccumulator') as scope:
      self._name = scope
      self.resource_handle = self._create_resource()
      self._init_op = self._initialize()
      initialized_check = self.is_initialized()
    resources.register_resource(self.resource_handle, self._init_op,
                                initialized_check)
    saveable = QuantileAccumulatorSaveable(
        self.resource_handle, self._init_op, self._num_streams,
        self.resource_handle.name)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)

  def _create_resource(self):
    """Returns the handle op, named and shared under this object's name."""
    return quantile_resource_handle_op(
        container='', shared_name=self._name, name=self._name)

  def _initialize(self):
    """Returns the op that creates the quantile stream resource."""
    return create_quantile_stream_resource(self.resource_handle, self._eps,
                                           self._num_streams)

  @property
  def initializer(self):
    """The init op; it is created on first access if it is still None."""
    init_op = self._init_op
    if init_op is None:
      init_op = self._init_op = self._initialize()
    return init_op

  def is_initialized(self):
    """Returns the op that checks whether the resource is initialized."""
    return is_quantile_resource_initialized(self.resource_handle)

  def _serialize_to_tensors(self):
    """Not implemented; always raises NotImplementedError."""
    raise NotImplementedError('When the need arises, TF2 compatibility can be '
                              'added by implementing this method, along with '
                              '_restore_from_tensors below.')

  def _restore_from_tensors(self, restored_tensors):
    """Not implemented; always raises NotImplementedError."""
    raise NotImplementedError('When the need arises, TF2 compatibility can be '
                              'added by implementing this method, along with '
                              '_serialize_to_tensors above.')

  def add_summaries(self, float_columns, example_weights):
    """Returns an op that adds summaries of the inputs to the resource.

    Args:
      float_columns: forwarded to the quantile summary op.
      example_weights: forwarded to the quantile summary op.
    """
    return quantile_add_summaries(
        self.resource_handle,
        make_quantile_summaries(float_columns, example_weights, self._eps))

  def flush(self):
    """Returns the flush op for the resource, given `num_quantiles`."""
    return quantile_flush(self.resource_handle, self._num_quantiles)

  def get_bucket_boundaries(self):
    """Returns the bucket boundaries op output for all streams."""
    # Inside a method this name resolves to the module-level op, not to this
    # method.
    return get_bucket_boundaries(self.resource_handle, self._num_streams)
| |
| |
class _TreeEnsembleSavable(saver.BaseSaverBuilder.SaveableObject):
  """SaveableObject that checkpoints a tree ensemble resource."""

  def __init__(self, resource_handle, create_op, name):
    """Creates a _TreeEnsembleSavable object.

    Args:
      resource_handle: handle to the decision tree ensemble variable.
      create_op: the op to initialize the variable.
      name: the name to save the tree ensemble variable under.
    """
    serialize_outputs = gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
        resource_handle)
    stamp_token, serialized = serialize_outputs
    # A slice spec selects part of a variable when saving. That is not
    # meaningful for the tree ensemble variable, so an empty value is passed.
    no_slice = ''
    stamp_spec = saver.BaseSaverBuilder.SaveSpec(stamp_token, no_slice,
                                                 name + '_stamp')
    proto_spec = saver.BaseSaverBuilder.SaveSpec(serialized, no_slice,
                                                 name + '_serialized')
    super().__init__(resource_handle, [stamp_spec, proto_spec], name)
    self.resource_handle = resource_handle
    self._create_op = create_op

  def restore(self, restored_tensors, unused_restored_shapes):
    """Restores the associated tree ensemble from 'restored_tensors'.

    Args:
      restored_tensors: the tensors that were loaded from a checkpoint; index
        0 is used as the stamp token and index 1 as the serialized ensemble.
      unused_restored_shapes: the shapes this object should conform to after
        restore. Not meaningful for trees.

    Returns:
      The operation that restores the state of the tree ensemble variable.
    """
    with ops.control_dependencies([self._create_op]):
      deserialize_op = gen_boosted_trees_ops.boosted_trees_deserialize_ensemble(
          self.resource_handle,
          stamp_token=restored_tensors[0],
          tree_ensemble_serialized=restored_tensors[1])
      return deserialize_op
| |
| |
class TreeEnsemble:
  """Creates and wraps a boosted trees ensemble resource."""

  def __init__(self, name, stamp_token=0, is_local=False, serialized_proto=''):
    """Creates the ensemble resource and registers it.

    Args:
      name: name scope for the resource; 'TreeEnsemble' is the default scope
        name.
      stamp_token: stamp token passed to the op that creates the ensemble.
      is_local: if True, the resource is registered as not shared and no
        saveable object is added to the SAVEABLE_OBJECTS collection.
      serialized_proto: serialized ensemble passed to the op that creates the
        ensemble.
    """
    self._stamp_token = stamp_token
    self._serialized_proto = serialized_proto
    self._is_local = is_local
    with ops.name_scope(name, 'TreeEnsemble') as scope:
      self._name = scope
      self.resource_handle = self._create_resource()
      self._init_op = self._initialize()
      initialized_check = self.is_initialized()
      # Only non-local ensembles are added to the savable list.
      if not is_local:
        saveable = _TreeEnsembleSavable(
            self.resource_handle, self.initializer, self.resource_handle.name)
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
      resources.register_resource(
          self.resource_handle,
          self.initializer,
          initialized_check,
          is_shared=not is_local)

  def _create_resource(self):
    """Returns the handle op, named and shared under this object's name."""
    return gen_boosted_trees_ops.boosted_trees_ensemble_resource_handle_op(
        container='', shared_name=self._name, name=self._name)

  def _initialize(self):
    """Returns the op that creates the ensemble from the constructor args."""
    return gen_boosted_trees_ops.boosted_trees_create_ensemble(
        self.resource_handle,
        self._stamp_token,
        tree_ensemble_serialized=self._serialized_proto)

  @property
  def initializer(self):
    """The init op; it is created on first access if it is still None."""
    init_op = self._init_op
    if init_op is None:
      init_op = self._init_op = self._initialize()
    return init_op

  def is_initialized(self):
    """Returns the op that checks whether the ensemble is initialized."""
    return gen_boosted_trees_ops.is_boosted_trees_ensemble_initialized(
        self.resource_handle)

  def _serialize_to_tensors(self):
    """Not implemented; always raises NotImplementedError."""
    raise NotImplementedError('When the need arises, TF2 compatibility can be '
                              'added by implementing this method, along with '
                              '_restore_from_tensors below.')

  def _restore_from_tensors(self, restored_tensors):
    """Not implemented; always raises NotImplementedError."""
    raise NotImplementedError('When the need arises, TF2 compatibility can be '
                              'added by implementing this method, along with '
                              '_serialize_to_tensors above.')

  def get_stamp_token(self):
    """Returns the current stamp token of the resource."""
    ensemble_states = gen_boosted_trees_ops.boosted_trees_get_ensemble_states(
        self.resource_handle)
    # The stamp token is the first of the state outputs.
    return ensemble_states[0]

  def get_states(self):
    """Returns states of the tree ensemble.

    Returns:
      stamp_token, num_trees, num_finalized_trees, num_attempted_layers and
      range of the nodes in the latest layer.
    """
    ensemble_states = gen_boosted_trees_ops.boosted_trees_get_ensemble_states(
        self.resource_handle)
    # Names given to the outputs, in the order the op returns them.
    output_names = ('stamp_token', 'num_trees', 'num_finalized_trees',
                    'num_attempted_layers', 'last_layer_nodes_range')
    # Use identity to give names.
    return tuple(
        array_ops.identity(state, name=label)
        for state, label in zip(ensemble_states, output_names))

  def serialize(self):
    """Serializes the ensemble into proto and returns the serialized proto.

    Returns:
      stamp_token: int64 scalar Tensor to denote the stamp of the resource.
      serialized_proto: string scalar Tensor of the serialized proto.
    """
    serialize_outputs = gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
        self.resource_handle)
    return serialize_outputs

  def deserialize(self, stamp_token, serialized_proto):
    """Deserialize the input proto and resets the ensemble from it.

    Args:
      stamp_token: int64 scalar Tensor to denote the stamp of the resource.
      serialized_proto: string scalar Tensor of the serialized proto.

    Returns:
      Operation (for dependencies).
    """
    deserialize_op = gen_boosted_trees_ops.boosted_trees_deserialize_ensemble(
        self.resource_handle, stamp_token, serialized_proto)
    return deserialize_op