# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for boosted_trees."""
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_boosted_trees_ops
from tensorflow.python.ops import resources

# Re-exporting ops used by other modules.
# pylint: disable=unused-import
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_aggregate_stats
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_bucketize
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_feature_split as calculate_best_feature_split
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_feature_split_v2 as calculate_best_feature_split_v2
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_center_bias as center_bias
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_create_quantile_stream_resource as create_quantile_stream_resource
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_example_debug_outputs as example_debug_outputs
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_quantile_summaries as make_quantile_summaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_add_summaries as quantile_add_summaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_deserialize as quantile_resource_deserialize
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_flush as quantile_flush
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_get_bucket_boundaries as get_bucket_boundaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_handle_op as quantile_resource_handle_op
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_sparse_aggregate_stats
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_sparse_calculate_best_feature_split as sparse_calculate_best_feature_split
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble as update_ensemble
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble_v2 as update_ensemble_v2
from tensorflow.python.ops.gen_boosted_trees_ops import is_boosted_trees_quantile_stream_resource_initialized as is_quantile_resource_initialized
# pylint: enable=unused-import

from tensorflow.python.trackable import resource
from tensorflow.python.training import saver


class PruningMode:
  """Class for working with Pruning modes."""
  NO_PRUNING, PRE_PRUNING, POST_PRUNING = range(0, 3)

  _map = {'none': NO_PRUNING, 'pre': PRE_PRUNING, 'post': POST_PRUNING}

  @classmethod
  def from_str(cls, mode):
    if mode in cls._map:
      return cls._map[mode]
    else:
      raise ValueError(
          'pruning_mode must be one of: {}. Found: {}'.format(', '.join(
              sorted(cls._map)), mode))


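# Illustrative sketch (not part of the original module): resolving the
# user-facing pruning mode string into the integer constant consumed by the
# ensemble update ops. The helper name is hypothetical.
def _example_pruning_mode():
  # 'none' -> NO_PRUNING (0), 'pre' -> PRE_PRUNING (1),
  # 'post' -> POST_PRUNING (2); any other string raises ValueError.
  return PruningMode.from_str('post')

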
class QuantileAccumulatorSaveable(saver.BaseSaverBuilder.SaveableObject):
  """SaveableObject implementation for QuantileAccumulator."""

  def __init__(self, resource_handle, create_op, num_streams, name):
    self._resource_handle = resource_handle
    self._num_streams = num_streams
    self._create_op = create_op
    bucket_boundaries = get_bucket_boundaries(self._resource_handle,
                                              self._num_streams)
    slice_spec = ''
    specs = []

    def make_save_spec(tensor, suffix):
      return saver.BaseSaverBuilder.SaveSpec(tensor, slice_spec, name + suffix)

    for i in range(self._num_streams):
      specs += [
          make_save_spec(bucket_boundaries[i], '_bucket_boundaries_' + str(i))
      ]
    super(QuantileAccumulatorSaveable, self).__init__(self._resource_handle,
                                                      specs, name)

  def restore(self, restored_tensors, unused_tensor_shapes):
    bucket_boundaries = restored_tensors
    with ops.control_dependencies([self._create_op]):
      return quantile_resource_deserialize(
          self._resource_handle, bucket_boundaries=bucket_boundaries)


class QuantileAccumulator(resource.TrackableResource):
  """TrackableResource wrapping the boosted trees quantile stream resource.

     The bucket boundaries are serialized to and deserialized from checkpoints.
  """

  def __init__(self,
               epsilon,
               num_streams,
               num_quantiles,
               name=None,
               max_elements=None):
    self._eps = epsilon
    self._num_streams = num_streams
    self._num_quantiles = num_quantiles
    super(QuantileAccumulator, self).__init__()

    with ops.name_scope(name, 'QuantileAccumulator') as name:
      self._name = name
      self._resource_handle = self._create_resource()
      self._init_op = self._initialize()
      is_initialized_op = self.is_initialized()
    resources.register_resource(self.resource_handle, self._init_op,
                                is_initialized_op)
    self._saveable = QuantileAccumulatorSaveable(
        self.resource_handle, self._init_op, self._num_streams,
        self.resource_handle.name)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)

  def _create_resource(self):
    return quantile_resource_handle_op(
        container='', shared_name=self._name, name=self._name)

  def _initialize(self):
    return create_quantile_stream_resource(self.resource_handle, self._eps,
                                           self._num_streams)

  @property
  def initializer(self):
    if self._init_op is None:
      self._init_op = self._initialize()
    return self._init_op

  def is_initialized(self):
    return is_quantile_resource_initialized(self.resource_handle)

  @property
  def saveable(self):
    return self._saveable

  def _gather_saveables_for_checkpoint(self):
    return {'quantile_accumulator': self._saveable}

  def add_summaries(self, float_columns, example_weights):
    summaries = make_quantile_summaries(float_columns, example_weights,
                                        self._eps)
    summary_op = quantile_add_summaries(self.resource_handle, summaries)
    return summary_op

  def flush(self):
    return quantile_flush(self.resource_handle, self._num_quantiles)

  def get_bucket_boundaries(self):
    return get_bucket_boundaries(self.resource_handle, self._num_streams)


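# Illustrative sketch (not part of the original module): the typical graph-mode
# flow for a QuantileAccumulator. Summaries are added per batch, the stream is
# flushed, and the resulting boundaries are used to bucketize the features.
# The function and argument names below are placeholders, not an existing API.
def _example_quantile_accumulator_flow(float_columns, example_weights):
  accumulator = QuantileAccumulator(
      epsilon=0.01,
      num_streams=len(float_columns),
      num_quantiles=10,
      name='example_quantile_accumulator')
  add_op = accumulator.add_summaries(float_columns, example_weights)
  # flush() should only run after the summaries have been added.
  with ops.control_dependencies([add_op]):
    flush_op = accumulator.flush()
  with ops.control_dependencies([flush_op]):
    boundaries = accumulator.get_bucket_boundaries()
  # Bucketize the same columns using the computed boundaries.
  return boosted_trees_bucketize(float_columns, boundaries)

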
class _TreeEnsembleSavable(saver.BaseSaverBuilder.SaveableObject):
  """SaveableObject implementation for TreeEnsemble."""

  def __init__(self, resource_handle, create_op, name):
    """Creates a _TreeEnsembleSavable object.

    Args:
      resource_handle: handle to the decision tree ensemble variable.
      create_op: the op to initialize the variable.
      name: the name to save the tree ensemble variable under.
    """
    stamp_token, serialized = (
        gen_boosted_trees_ops.boosted_trees_serialize_ensemble(resource_handle))
    # slice_spec is useful for saving a slice from a variable.
    # It's not meaningful for the tree ensemble variable, so we just pass an
    # empty value.
    slice_spec = ''
    specs = [
        saver.BaseSaverBuilder.SaveSpec(stamp_token, slice_spec,
                                        name + '_stamp'),
        saver.BaseSaverBuilder.SaveSpec(serialized, slice_spec,
                                        name + '_serialized'),
    ]
    super(_TreeEnsembleSavable, self).__init__(resource_handle, specs, name)
    self._resource_handle = resource_handle
    self._create_op = create_op

  def restore(self, restored_tensors, unused_restored_shapes):
    """Restores the associated tree ensemble from 'restored_tensors'.

    Args:
      restored_tensors: the tensors that were loaded from a checkpoint.
      unused_restored_shapes: the shapes this object should conform to after
        restore. Not meaningful for trees.

    Returns:
      The operation that restores the state of the tree ensemble variable.
    """
    with ops.control_dependencies([self._create_op]):
      return gen_boosted_trees_ops.boosted_trees_deserialize_ensemble(
          self._resource_handle,
          stamp_token=restored_tensors[0],
          tree_ensemble_serialized=restored_tensors[1])


class TreeEnsemble(resource.TrackableResource):
  """Creates a TreeEnsemble resource."""

  def __init__(self, name, stamp_token=0, is_local=False, serialized_proto=''):
    self._stamp_token = stamp_token
    self._serialized_proto = serialized_proto
    self._is_local = is_local
    with ops.name_scope(name, 'TreeEnsemble') as name:
      self._name = name
      self._resource_handle = self._create_resource()
      self._init_op = self._initialize()
      is_initialized_op = self.is_initialized()
      # Adds the variable to the savable list.
      if not is_local:
        self._saveable = _TreeEnsembleSavable(
            self.resource_handle, self.initializer, self.resource_handle.name)
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)
      resources.register_resource(
          self.resource_handle,
          self.initializer,
          is_initialized_op,
          is_shared=not is_local)

  def _create_resource(self):
    return gen_boosted_trees_ops.boosted_trees_ensemble_resource_handle_op(
        container='', shared_name=self._name, name=self._name)

  def _initialize(self):
    return gen_boosted_trees_ops.boosted_trees_create_ensemble(
        self.resource_handle,
        self._stamp_token,
        tree_ensemble_serialized=self._serialized_proto)

  @property
  def initializer(self):
    if self._init_op is None:
      self._init_op = self._initialize()
    return self._init_op

  def is_initialized(self):
    return gen_boosted_trees_ops.is_boosted_trees_ensemble_initialized(
        self.resource_handle)

  def _gather_saveables_for_checkpoint(self):
    if not self._is_local:
      return {'tree_ensemble': self._saveable}

  def get_stamp_token(self):
    """Returns the current stamp token of the resource."""
    stamp_token, _, _, _, _ = (
        gen_boosted_trees_ops.boosted_trees_get_ensemble_states(
            self.resource_handle))
    return stamp_token

  def get_states(self):
    """Returns states of the tree ensemble.

    Returns:
      stamp_token, num_trees, num_finalized_trees, num_attempted_layers and
      range of the nodes in the latest layer.
    """
    (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
     nodes_range) = (
         gen_boosted_trees_ops.boosted_trees_get_ensemble_states(
             self.resource_handle))
    # Use identity to give names.
    return (array_ops.identity(stamp_token, name='stamp_token'),
            array_ops.identity(num_trees, name='num_trees'),
            array_ops.identity(num_finalized_trees, name='num_finalized_trees'),
            array_ops.identity(
                num_attempted_layers, name='num_attempted_layers'),
            array_ops.identity(nodes_range, name='last_layer_nodes_range'))

  def serialize(self):
    """Serializes the ensemble into proto and returns the serialized proto.

    Returns:
      stamp_token: int64 scalar Tensor to denote the stamp of the resource.
      serialized_proto: string scalar Tensor of the serialized proto.
    """
    return gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
        self.resource_handle)

  def deserialize(self, stamp_token, serialized_proto):
    """Deserializes the input proto and resets the ensemble from it.

    Args:
      stamp_token: int64 scalar Tensor to denote the stamp of the resource.
      serialized_proto: string scalar Tensor of the serialized proto.

    Returns:
      Operation (for dependencies).
    """
    return gen_boosted_trees_ops.boosted_trees_deserialize_ensemble(
        self.resource_handle, stamp_token, serialized_proto)
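

# Illustrative sketch (not part of the original module): constructing a
# TreeEnsemble, reading its state counters, and round-tripping its contents
# through serialize()/deserialize(). The helper and resource names below are
# placeholders for illustration only.
def _example_tree_ensemble_roundtrip():
  ensemble = TreeEnsemble(name='example_tree_ensemble', stamp_token=0)
  # Named tensors describing the ensemble: stamp token, tree counts, attempted
  # layers and the node id range of the latest layer.
  states = ensemble.get_states()
  # Serialize the ensemble and restore it into the same resource; in practice
  # the serialized proto would come from a checkpoint or another worker.
  stamp_token, serialized_proto = ensemble.serialize()
  restore_op = ensemble.deserialize(stamp_token, serialized_proto)
  return states, restore_op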
302