# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for boosted_trees."""
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_boosted_trees_ops
from tensorflow.python.ops import resources

# Re-exporting ops used by other modules.
# pylint: disable=unused-import
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_aggregate_stats
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_bucketize
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_feature_split as calculate_best_feature_split
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_feature_split_v2 as calculate_best_feature_split_v2
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_center_bias as center_bias
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_create_quantile_stream_resource as create_quantile_stream_resource
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_example_debug_outputs as example_debug_outputs
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_quantile_summaries as make_quantile_summaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_add_summaries as quantile_add_summaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_deserialize as quantile_resource_deserialize
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_flush as quantile_flush
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_get_bucket_boundaries as get_bucket_boundaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_handle_op as quantile_resource_handle_op
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_sparse_aggregate_stats
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_sparse_calculate_best_feature_split as sparse_calculate_best_feature_split
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble as update_ensemble
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble_v2 as update_ensemble_v2
from tensorflow.python.ops.gen_boosted_trees_ops import is_boosted_trees_quantile_stream_resource_initialized as is_quantile_resource_initialized
# pylint: enable=unused-import

from tensorflow.python.trackable import resource
from tensorflow.python.training import saver


class PruningMode:
  """Integer constants for pruning modes, plus a parser from string names."""
  # NO_PRUNING == 0, PRE_PRUNING == 1, POST_PRUNING == 2.
  NO_PRUNING, PRE_PRUNING, POST_PRUNING = range(0, 3)

  # Accepted string spellings and the constant each one selects.
  _map = {'none': NO_PRUNING, 'pre': PRE_PRUNING, 'post': POST_PRUNING}

  @classmethod
  def from_str(cls, mode):
    """Converts a pruning mode name into its integer constant.

    Args:
      mode: one of 'none', 'pre' or 'post'.

    Returns:
      The matching NO_PRUNING / PRE_PRUNING / POST_PRUNING value.

    Raises:
      ValueError: if `mode` is not one of the accepted names.
    """
    if mode in cls._map:
      return cls._map[mode]
    raise ValueError(
        'pruning_mode mode must be one of: {}. Found: {}'.format(
            ', '.join(sorted(cls._map)), mode))


class QuantileAccumulatorSaveable(saver.BaseSaverBuilder.SaveableObject):
  """SaveableObject implementation for QuantileAccumulator."""

  def __init__(self, resource_handle, create_op, num_streams, name):
    """Creates a QuantileAccumulatorSaveable object.

    One SaveSpec is built per stream, named
    `name + '_bucket_boundaries_' + str(i)`.

    Args:
      resource_handle: handle to the quantile stream resource.
      create_op: the op that creates the resource; `restore` runs after it.
      num_streams: number of streams, i.e. number of bucket-boundary tensors
        read from the resource.
      name: prefix used for the names of the saved tensors.
    """
    self._resource_handle = resource_handle
    self._num_streams = num_streams
    self._create_op = create_op
    bucket_boundaries = get_bucket_boundaries(self._resource_handle,
                                              self._num_streams)
    # Bucket boundaries are saved whole, so the slice spec is left empty.
    slice_spec = ''
    specs = [
        saver.BaseSaverBuilder.SaveSpec(bucket_boundaries[i], slice_spec,
                                        name + '_bucket_boundaries_' + str(i))
        for i in range(self._num_streams)
    ]
    super().__init__(self._resource_handle, specs, name)

  def restore(self, restored_tensors, unused_tensor_shapes):
    """Restores the quantile stream resource from `restored_tensors`.

    Args:
      restored_tensors: the tensors that were loaded from a checkpoint; they
        are passed through as the bucket boundaries.
      unused_tensor_shapes: ignored.

    Returns:
      The deserialize op, which is created under a control dependency on the
      create op given to the constructor.
    """
    bucket_boundaries = restored_tensors
    with ops.control_dependencies([self._create_op]):
      return quantile_resource_deserialize(
          self._resource_handle, bucket_boundaries=bucket_boundaries)


class QuantileAccumulator(resource.TrackableResource):
  """Trackable wrapper around a boosted trees quantile stream resource.

  The bucket boundaries are serialized and deserialized from checkpointing.
  """

  def __init__(self,
               epsilon,
               num_streams,
               num_quantiles,
               name=None,
               max_elements=None):
    """Creates the resource handle, its init op and its saveable.

    The resource is registered through `resources.register_resource` and its
    saveable is added to the `SAVEABLE_OBJECTS` graph collection.

    Args:
      epsilon: forwarded to the create-resource and make-summaries ops
        (presumably the approximation error tolerance; confirm against the
        op documentation).
      num_streams: number of streams held by the resource.
      num_quantiles: forwarded to the flush op by `flush()`.
      name: optional name scope; defaults to 'QuantileAccumulator'.
      max_elements: not read by this constructor.
    """
    self._eps = epsilon
    self._num_streams = num_streams
    self._num_quantiles = num_quantiles
    super().__init__()

    with ops.name_scope(name, 'QuantileAccumulator') as name:
      self._name = name
      self._resource_handle = self._create_resource()
      self._init_op = self._initialize()
      is_initialized_op = self.is_initialized()
      resources.register_resource(self.resource_handle, self._init_op,
                                  is_initialized_op)
      self._saveable = QuantileAccumulatorSaveable(
          self.resource_handle, self._init_op, self._num_streams,
          self.resource_handle.name)
      ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)

  def _create_resource(self):
    """Returns a new quantile stream resource handle named after this object."""
    return quantile_resource_handle_op(
        container='', shared_name=self._name, name=self._name)

  def _initialize(self):
    """Returns the op that creates the quantile stream resource."""
    return create_quantile_stream_resource(self.resource_handle, self._eps,
                                           self._num_streams)

  @property
  def initializer(self):
    """The op that creates the resource, built on first access if needed."""
    if self._init_op is None:
      self._init_op = self._initialize()
    return self._init_op

  def is_initialized(self):
    """Returns the op that checks whether the resource is initialized."""
    return is_quantile_resource_initialized(self.resource_handle)

  @property
  def saveable(self):
    """The QuantileAccumulatorSaveable built for this resource."""
    return self._saveable

  def _gather_saveables_for_checkpoint(self):
    """Returns a dict mapping the saveable's name to the saveable."""
    # Must be a keyed mapping (the same shape TreeEnsemble returns); a
    # comma here instead of a colon would build a two-element set.
    return {'quantile_accumulator': self._saveable}

  def add_summaries(self, float_columns, example_weights):
    """Builds quantile summaries and adds them to the resource.

    Args:
      float_columns: forwarded to the make-quantile-summaries op.
      example_weights: forwarded to the make-quantile-summaries op.

    Returns:
      The op that adds the computed summaries to the resource.
    """
    summaries = make_quantile_summaries(float_columns, example_weights,
                                        self._eps)
    summary_op = quantile_add_summaries(self.resource_handle, summaries)
    return summary_op

  def flush(self):
    """Returns the flush op for the resource, using `num_quantiles`."""
    return quantile_flush(self.resource_handle, self._num_quantiles)

  def get_bucket_boundaries(self):
    """Returns the bucket boundaries read from the resource, one per stream."""
    # Inside a method the bare name resolves to the module-level op import,
    # not to this method.
    return get_bucket_boundaries(self.resource_handle, self._num_streams)


class _TreeEnsembleSavable(saver.BaseSaverBuilder.SaveableObject):
  """SaveableObject implementation for TreeEnsemble."""

  def __init__(self, resource_handle, create_op, name):
    """Creates a _TreeEnsembleSavable object.

    Args:
      resource_handle: handle to the decision tree ensemble variable.
      create_op: the op to initialize the variable.
      name: the name to save the tree ensemble variable under.
    """
    stamp_token, serialized = (
        gen_boosted_trees_ops.boosted_trees_serialize_ensemble(resource_handle))
    # slice_spec is useful for saving a slice from a variable.
    # It's not meaningful for the tree ensemble variable, so we just pass an
    # empty value.
    slice_spec = ''
    specs = [
        saver.BaseSaverBuilder.SaveSpec(stamp_token, slice_spec,
                                        name + '_stamp'),
        saver.BaseSaverBuilder.SaveSpec(serialized, slice_spec,
                                        name + '_serialized'),
    ]
    super().__init__(resource_handle, specs, name)
    self._resource_handle = resource_handle
    self._create_op = create_op

  def restore(self, restored_tensors, unused_restored_shapes):
    """Restores the associated tree ensemble from 'restored_tensors'.

    Args:
      restored_tensors: the tensors that were loaded from a checkpoint. Index
        0 is used as the stamp token and index 1 as the serialized ensemble.
      unused_restored_shapes: the shapes this object should conform to after
        restore. Not meaningful for trees.

    Returns:
      The operation that restores the state of the tree ensemble variable.
    """
    with ops.control_dependencies([self._create_op]):
      return gen_boosted_trees_ops.boosted_trees_deserialize_ensemble(
          self._resource_handle,
          stamp_token=restored_tensors[0],
          tree_ensemble_serialized=restored_tensors[1])


class TreeEnsemble(resource.TrackableResource):
  """Creates TreeEnsemble resource."""

  def __init__(self, name, stamp_token=0, is_local=False, serialized_proto=''):
    """Creates the ensemble resource handle, its init op and its saveable.

    Args:
      name: name scope for the resource; defaults to 'TreeEnsemble' when None.
      stamp_token: stamp token passed to the create-ensemble op.
      is_local: when True, no saveable is created and the resource is
        registered as not shared.
      serialized_proto: serialized ensemble passed to the create-ensemble op.
    """
    # NOTE(review): unlike QuantileAccumulator, this constructor does not call
    # the base-class __init__ -- confirm that this is intentional.
    self._stamp_token = stamp_token
    self._serialized_proto = serialized_proto
    self._is_local = is_local
    with ops.name_scope(name, 'TreeEnsemble') as name:
      self._name = name
      self._resource_handle = self._create_resource()
      self._init_op = self._initialize()
      is_initialized_op = self.is_initialized()
      # Adds the variable to the savable list.
      if not is_local:
        self._saveable = _TreeEnsembleSavable(
            self.resource_handle, self.initializer, self.resource_handle.name)
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)
      resources.register_resource(
          self.resource_handle,
          self.initializer,
          is_initialized_op,
          is_shared=not is_local)

  def _create_resource(self):
    """Returns a new ensemble resource handle named after this object."""
    return gen_boosted_trees_ops.boosted_trees_ensemble_resource_handle_op(
        container='', shared_name=self._name, name=self._name)

  def _initialize(self):
    """Returns the op that creates the ensemble from the constructor args."""
    return gen_boosted_trees_ops.boosted_trees_create_ensemble(
        self.resource_handle,
        self._stamp_token,
        tree_ensemble_serialized=self._serialized_proto)

  @property
  def initializer(self):
    """The op that creates the ensemble, built on first access if needed."""
    if self._init_op is None:
      self._init_op = self._initialize()
    return self._init_op

  def is_initialized(self):
    """Returns the op that checks whether the ensemble is initialized."""
    return gen_boosted_trees_ops.is_boosted_trees_ensemble_initialized(
        self.resource_handle)

  def _gather_saveables_for_checkpoint(self):
    """Returns a dict mapping the saveable's name to the saveable.

    Returns:
      `{'tree_ensemble': saveable}` for a non-local ensemble, and an empty
      dict for a local one.
    """
    if self._is_local:
      # Local ensembles never build a saveable (see __init__), so there is
      # nothing to report; return an empty mapping rather than None.
      return {}
    return {'tree_ensemble': self._saveable}

  def get_stamp_token(self):
    """Returns the current stamp token of the resource."""
    stamp_token, _, _, _, _ = (
        gen_boosted_trees_ops.boosted_trees_get_ensemble_states(
            self.resource_handle))
    return stamp_token

  def get_states(self):
    """Returns states of the tree ensemble.

    Returns:
      stamp_token, num_trees, num_finalized_trees, num_attempted_layers and
      range of the nodes in the latest layer.
    """
    (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
     nodes_range) = (
         gen_boosted_trees_ops.boosted_trees_get_ensemble_states(
             self.resource_handle))
    # Use identity to give names.
    return (array_ops.identity(stamp_token, name='stamp_token'),
            array_ops.identity(num_trees, name='num_trees'),
            array_ops.identity(num_finalized_trees, name='num_finalized_trees'),
            array_ops.identity(
                num_attempted_layers, name='num_attempted_layers'),
            array_ops.identity(nodes_range, name='last_layer_nodes_range'))

  def serialize(self):
    """Serializes the ensemble into proto and returns the serialized proto.

    Returns:
      stamp_token: int64 scalar Tensor to denote the stamp of the resource.
      serialized_proto: string scalar Tensor of the serialized proto.
    """
    return gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
        self.resource_handle)

  def deserialize(self, stamp_token, serialized_proto):
    """Deserialize the input proto and resets the ensemble from it.

    Args:
      stamp_token: int64 scalar Tensor to denote the stamp of the resource.
      serialized_proto: string scalar Tensor of the serialized proto.

    Returns:
      Operation (for dependencies).
    """
    return gen_boosted_trees_ops.boosted_trees_deserialize_ensemble(
        self.resource_handle, stamp_token, serialized_proto)