# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for boosted_trees."""
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_boosted_trees_ops
from tensorflow.python.ops import resources
# Re-exporting ops used by other modules.
# pylint: disable=unused-import
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_aggregate_stats
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_bucketize
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_feature_split as calculate_best_feature_split
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_feature_split_v2 as calculate_best_feature_split_v2
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_center_bias as center_bias
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_create_quantile_stream_resource as create_quantile_stream_resource
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_example_debug_outputs as example_debug_outputs
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_quantile_summaries as make_quantile_summaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_add_summaries as quantile_add_summaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_deserialize as quantile_resource_deserialize
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_flush as quantile_flush
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_get_bucket_boundaries as get_bucket_boundaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_handle_op as quantile_resource_handle_op
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_sparse_aggregate_stats
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_sparse_calculate_best_feature_split as sparse_calculate_best_feature_split
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble as update_ensemble
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble_v2 as update_ensemble_v2
from tensorflow.python.ops.gen_boosted_trees_ops import is_boosted_trees_quantile_stream_resource_initialized as is_quantile_resource_initialized
# pylint: enable=unused-import
from tensorflow.python.training import saver


class PruningMode:
"""Class for working with Pruning modes."""
NO_PRUNING, PRE_PRUNING, POST_PRUNING = range(0, 3)
_map = {'none': NO_PRUNING, 'pre': PRE_PRUNING, 'post': POST_PRUNING}
@classmethod
def from_str(cls, mode):
if mode in cls._map:
return cls._map[mode]
else:
      raise ValueError(
          'pruning_mode must be one of: {}. Found: {}'.format(', '.join(
              sorted(cls._map)), mode))
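
# Illustrative use of PruningMode.from_str (a sketch, not executed here):
#
#   PruningMode.from_str('post')  # -> PruningMode.POST_PRUNING (i.e. 2)
#   PruningMode.from_str('fast')  # raises ValueError listing 'none, post, pre'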


class QuantileAccumulatorSaveable(saver.BaseSaverBuilder.SaveableObject):
"""SaveableObject implementation for QuantileAccumulator."""
def __init__(self, resource_handle, create_op, num_streams, name):
self.resource_handle = resource_handle
self._num_streams = num_streams
self._create_op = create_op
bucket_boundaries = get_bucket_boundaries(self.resource_handle,
self._num_streams)
slice_spec = ''
specs = []
def make_save_spec(tensor, suffix):
return saver.BaseSaverBuilder.SaveSpec(tensor, slice_spec, name + suffix)
for i in range(self._num_streams):
specs += [
make_save_spec(bucket_boundaries[i], '_bucket_boundaries_' + str(i))
]
super(QuantileAccumulatorSaveable, self).__init__(self.resource_handle,
specs, name)
def restore(self, restored_tensors, unused_tensor_shapes):
bucket_boundaries = restored_tensors
with ops.control_dependencies([self._create_op]):
return quantile_resource_deserialize(
self.resource_handle, bucket_boundaries=bucket_boundaries)


class QuantileAccumulator():
  """Resource wrapper for the boosted trees quantile stream.

  The bucket boundaries are serialized to and deserialized from checkpoints.
  """
def __init__(self,
epsilon,
num_streams,
num_quantiles,
name=None,
max_elements=None):
del max_elements # Unused.
self._eps = epsilon
self._num_streams = num_streams
self._num_quantiles = num_quantiles
with ops.name_scope(name, 'QuantileAccumulator') as name:
self._name = name
self.resource_handle = self._create_resource()
self._init_op = self._initialize()
is_initialized_op = self.is_initialized()
resources.register_resource(self.resource_handle, self._init_op,
is_initialized_op)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS,
QuantileAccumulatorSaveable(
self.resource_handle, self._init_op,
self._num_streams, self.resource_handle.name))
def _create_resource(self):
return quantile_resource_handle_op(
container='', shared_name=self._name, name=self._name)
def _initialize(self):
return create_quantile_stream_resource(self.resource_handle, self._eps,
self._num_streams)
@property
def initializer(self):
if self._init_op is None:
self._init_op = self._initialize()
return self._init_op
def is_initialized(self):
return is_quantile_resource_initialized(self.resource_handle)
def _serialize_to_tensors(self):
raise NotImplementedError('When the need arises, TF2 compatibility can be '
'added by implementing this method, along with '
'_restore_from_tensors below.')
def _restore_from_tensors(self, restored_tensors):
raise NotImplementedError('When the need arises, TF2 compatibility can be '
'added by implementing this method, along with '
'_serialize_to_tensors above.')
def add_summaries(self, float_columns, example_weights):
summaries = make_quantile_summaries(float_columns, example_weights,
self._eps)
summary_op = quantile_add_summaries(self.resource_handle, summaries)
return summary_op
def flush(self):
return quantile_flush(self.resource_handle, self._num_quantiles)
def get_bucket_boundaries(self):
return get_bucket_boundaries(self.resource_handle, self._num_streams)
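
# Illustrative QuantileAccumulator usage (a sketch; assumes graph-mode
# execution, that the registered resource initializer has been run, and
# caller-provided `float_columns` / `example_weights` tensors):
#
#   accumulator = QuantileAccumulator(
#       epsilon=0.01, num_streams=2, num_quantiles=100, name='q_accumulator')
#   summary_op = accumulator.add_summaries(float_columns, example_weights)
#   flush_op = accumulator.flush()
#   # After summary_op and flush_op have run, one boundaries tensor per stream
#   # can be fetched:
#   boundaries = accumulator.get_bucket_boundaries()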


class _TreeEnsembleSavable(saver.BaseSaverBuilder.SaveableObject):
"""SaveableObject implementation for TreeEnsemble."""
def __init__(self, resource_handle, create_op, name):
"""Creates a _TreeEnsembleSavable object.
Args:
resource_handle: handle to the decision tree ensemble variable.
create_op: the op to initialize the variable.
name: the name to save the tree ensemble variable under.
"""
stamp_token, serialized = (
gen_boosted_trees_ops.boosted_trees_serialize_ensemble(resource_handle))
    # slice_spec is useful for saving a slice from a variable.
    # It's not meaningful for the tree ensemble variable, so we just pass an
    # empty value.
slice_spec = ''
specs = [
saver.BaseSaverBuilder.SaveSpec(stamp_token, slice_spec,
name + '_stamp'),
saver.BaseSaverBuilder.SaveSpec(serialized, slice_spec,
name + '_serialized'),
]
super(_TreeEnsembleSavable, self).__init__(resource_handle, specs, name)
self.resource_handle = resource_handle
self._create_op = create_op
def restore(self, restored_tensors, unused_restored_shapes):
"""Restores the associated tree ensemble from 'restored_tensors'.
Args:
restored_tensors: the tensors that were loaded from a checkpoint.
unused_restored_shapes: the shapes this object should conform to after
restore. Not meaningful for trees.
Returns:
The operation that restores the state of the tree ensemble variable.
"""
with ops.control_dependencies([self._create_op]):
return gen_boosted_trees_ops.boosted_trees_deserialize_ensemble(
self.resource_handle,
stamp_token=restored_tensors[0],
tree_ensemble_serialized=restored_tensors[1])


class TreeEnsemble():
"""Creates TreeEnsemble resource."""
def __init__(self, name, stamp_token=0, is_local=False, serialized_proto=''):
self._stamp_token = stamp_token
self._serialized_proto = serialized_proto
self._is_local = is_local
with ops.name_scope(name, 'TreeEnsemble') as name:
self._name = name
self.resource_handle = self._create_resource()
self._init_op = self._initialize()
is_initialized_op = self.is_initialized()
      # Adds the variable to the saveable list.
if not is_local:
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS,
_TreeEnsembleSavable(
self.resource_handle, self.initializer,
self.resource_handle.name))
resources.register_resource(
self.resource_handle,
self.initializer,
is_initialized_op,
is_shared=not is_local)
def _create_resource(self):
return gen_boosted_trees_ops.boosted_trees_ensemble_resource_handle_op(
container='', shared_name=self._name, name=self._name)
def _initialize(self):
return gen_boosted_trees_ops.boosted_trees_create_ensemble(
self.resource_handle,
self._stamp_token,
tree_ensemble_serialized=self._serialized_proto)
@property
def initializer(self):
if self._init_op is None:
self._init_op = self._initialize()
return self._init_op
def is_initialized(self):
return gen_boosted_trees_ops.is_boosted_trees_ensemble_initialized(
self.resource_handle)
def _serialize_to_tensors(self):
raise NotImplementedError('When the need arises, TF2 compatibility can be '
'added by implementing this method, along with '
'_restore_from_tensors below.')
def _restore_from_tensors(self, restored_tensors):
raise NotImplementedError('When the need arises, TF2 compatibility can be '
'added by implementing this method, along with '
'_serialize_to_tensors above.')
def get_stamp_token(self):
"""Returns the current stamp token of the resource."""
stamp_token, _, _, _, _ = (
gen_boosted_trees_ops.boosted_trees_get_ensemble_states(
self.resource_handle))
return stamp_token
def get_states(self):
"""Returns states of the tree ensemble.
Returns:
stamp_token, num_trees, num_finalized_trees, num_attempted_layers and
range of the nodes in the latest layer.
"""
(stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
nodes_range) = (
gen_boosted_trees_ops.boosted_trees_get_ensemble_states(
self.resource_handle))
# Use identity to give names.
return (array_ops.identity(stamp_token, name='stamp_token'),
array_ops.identity(num_trees, name='num_trees'),
array_ops.identity(num_finalized_trees, name='num_finalized_trees'),
array_ops.identity(
num_attempted_layers, name='num_attempted_layers'),
array_ops.identity(nodes_range, name='last_layer_nodes_range'))
def serialize(self):
"""Serializes the ensemble into proto and returns the serialized proto.
Returns:
stamp_token: int64 scalar Tensor to denote the stamp of the resource.
serialized_proto: string scalar Tensor of the serialized proto.
"""
return gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
self.resource_handle)
def deserialize(self, stamp_token, serialized_proto):
"""Deserialize the input proto and resets the ensemble from it.
Args:
stamp_token: int64 scalar Tensor to denote the stamp of the resource.
serialized_proto: string scalar Tensor of the serialized proto.
Returns:
Operation (for dependencies).
"""
return gen_boosted_trees_ops.boosted_trees_deserialize_ensemble(
self.resource_handle, stamp_token, serialized_proto)
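
# Illustrative TreeEnsemble usage (a sketch; assumes graph-mode execution and
# that the registered resource initializer has been run):
#
#   ensemble = TreeEnsemble(name='tree_ensemble', stamp_token=0)
#   (stamp, num_trees, num_finalized_trees, num_attempted_layers,
#    nodes_range) = ensemble.get_states()
#   # Round-trip the ensemble through its serialized proto representation.
#   stamp_token, serialized = ensemble.serialize()
#   restore_op = ensemble.deserialize(stamp_token, serialized)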