/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {

class OpKernelContext;

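// Runtime switches (resolved at runtime by the implementation in the .cc
// file): whether to use deterministic, but possibly slower, GPU kernels for
// segment reductions, and whether to suppress the errors otherwise raised
// when a nondeterministic segment reduction would run while op determinism
// is required.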
bool UseDeterministicSegmentReductions();
bool DisableSegmentReductionOpDeterminismExceptions();

// The type of SparseSegmentReduction operation for which a gradient is being
// computed.
enum class SparseSegmentReductionOperation { kSum, kMean, kSqrtN };

namespace functor {

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Note that we define this ourselves to avoid a dependency on gpuprim.
struct Sum {
  template <typename T>
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return a + b;
  }
};

struct Prod {
  template <typename T>
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return a * b;
  }
};

// Note that we don't use gpuprim::Min/Max because they use operator<, which is
// not implemented for AlignedVector types.
struct Min {
  template <typename T>
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return min(a, b);
  }
};

struct Max {
  template <typename T>
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return max(a, b);
  }
};

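// Trait marking which (ReduceOp, T) pairs are exactly associative, i.e. for
// which the order in which an atomic reduction accumulates elements cannot
// change the result. Sum and Prod qualify only for integral T: floating-point
// addition and multiplication are not associative, so accumulating them in a
// nondeterministic order can produce run-to-run differences. Min and Max are
// associative for all element types.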
template <typename ReduceOp, typename T>
struct ReduceOpIsAssociative {};
template <typename T>
struct ReduceOpIsAssociative<functor::Sum, T> : std::is_integral<T> {};
template <typename T>
struct ReduceOpIsAssociative<functor::Prod, T> : std::is_integral<T> {};
template <typename T>
struct ReduceOpIsAssociative<functor::Max, T> : std::true_type {};
template <typename T>
struct ReduceOpIsAssociative<functor::Min, T> : std::true_type {};

typedef Eigen::GpuDevice GPUDevice;
// Functor for SegmentReductionGPUOp.
// output_rows: the number of output segments (unique segment ids in
//              'segment_ids').
// segment_ids_shape: shape of the 'segment_ids' tensor.
// is_mean: whether to divide each segment's reduction result by the number
//              of elements in that segment.
// segment_ids: mapping from input rows to the output segment ids at which
//              the segment reduction is performed.
// data_size: size of the input data tensor.
// data: input data tensor.
// output: output tensor, reshaped to {output_rows, output.size/output_rows}.
template <typename T, typename Index, typename InitialValueF,
          typename EmptySegmentValueF, typename ReductionF>
struct SegmentReductionFunctor {
  void operator()(OpKernelContext* ctx, const GPUDevice& d,
                  const Index output_rows, const TensorShape& segment_ids_shape,
                  bool is_mean, typename TTypes<Index>::ConstFlat segment_ids,
                  const Index data_size, const T* data,
                  typename TTypes<T, 2>::Tensor output);
  static constexpr bool atomic_reduction_is_associative =
      ReduceOpIsAssociative<ReductionF, T>::value;
};
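// For example (illustrative only, not an exhaustive list of registrations),
// a segment sum over float data with int32 segment ids might be instantiated
// as:
//   SegmentReductionFunctor<float, int32, functor::Zero<float>,
//                           functor::Zero<float>, functor::Sum>
// where Zero<float> supplies both the accumulator's initial value and the
// value written to empty segments.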

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

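// Functor for the UnsortedSegment* reductions. Unlike SegmentReductionFunctor
// above, 'segment_ids' may appear in any order (and ids outside the range
// [0, num_segments) are skipped). The output is first filled with
// InitialValueF() and then reduced into with ReductionF.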
template <typename Device, typename T, typename Index, typename InitialValueF,
          typename ReductionF>
struct UnsortedSegmentFunctor {
  void operator()(OpKernelContext* ctx, const TensorShape& segment_ids_shape,
                  typename TTypes<Index>::ConstFlat segment_ids,
                  typename TTypes<T, 2>::ConstTensor data,
                  typename TTypes<T, 2>::Tensor output);
};
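// As an illustration, an unsorted segment max over float on the GPU could be
// dispatched as:
//   UnsortedSegmentFunctor<GPUDevice, float, int32, functor::Lowest<float>,
//                          functor::Max>
// pairing the Max reduction with its identity element, Lowest<float>.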

// Initial value functors. Each returns the identity element of the reduction
// it is paired with: Zero for Sum, One for Prod, Lowest for Max, and Highest
// for Min.
template <typename T>
struct Zero {
  EIGEN_STRONG_INLINE T operator()() const { return T(0); }
};

template <typename T>
struct One {
  EIGEN_STRONG_INLINE T operator()() const { return T(1); }
};

template <typename T>
struct Lowest {
  EIGEN_STRONG_INLINE T operator()() const {
    return Eigen::NumTraits<T>::lowest();
  }
};

template <typename T>
struct Highest {
  EIGEN_STRONG_INLINE T operator()() const {
    return Eigen::NumTraits<T>::highest();
  }
};

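// Functor for the sparse segment reductions (SparseSegmentSum/Mean/SqrtN).
// Rows of 'input' selected by 'indices' are reduced into 'output' according
// to 'segment_ids'; 'is_mean' and 'is_sqrtn' select the normalization, and
// segments that receive no rows are filled with 'default_value'. Returns a
// Status so that invalid indices can be reported as errors.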
template <typename T, typename Index, typename SegmentId>
struct SparseSegmentReductionFunctor {
  Status operator()(OpKernelContext* context, bool is_mean, bool is_sqrtn,
                    T default_value, typename TTypes<T, 2>::ConstTensor input,
                    typename TTypes<Index>::ConstVec indices,
                    typename TTypes<SegmentId>::ConstVec segment_ids,
                    typename TTypes<T, 2>::Tensor output);
};

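// Functor for the gradients of the sparse segment reductions. 'operation'
// selects the forward op (Sum, Mean, or SqrtN) whose gradient is being
// computed: gradient rows are gathered per 'segment_vec', scaled as the
// forward normalization requires, and accumulated into 'output_flat' at the
// positions given by 'indices_vec'.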
template <class Device, typename T, typename Index, typename SegmentId>
struct SparseSegmentGradFunctor {
  void operator()(OpKernelContext* context,
                  SparseSegmentReductionOperation operation,
                  typename TTypes<T>::ConstMatrix input_flat,
                  typename TTypes<Index>::ConstVec indices_vec,
                  typename TTypes<SegmentId>::ConstVec segment_vec,
                  typename TTypes<T>::Matrix output_flat);
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_