xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/aggregate_ops.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_
17 #define TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_
18 
19 #include <numeric>
20 
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "tensorflow/core/framework/tensor_types.h"
23 
24 namespace tensorflow {
25 namespace functor {
26 
27 // Functor definitions for Aggregate ops, must be compilable by nvcc.
28 template <typename Device, typename T>
29 struct Add2Functor {
30   void operator()(const Device& d, typename TTypes<T>::Flat out,
31                   typename TTypes<T>::ConstFlat in1,
32                   typename TTypes<T>::ConstFlat in2);
33 };
34 
35 template <typename Device, typename T>
36 struct Add2EigenImpl {
ComputeAdd2EigenImpl37   static void Compute(const Device& d, typename TTypes<T>::Flat out,
38                       typename TTypes<T>::ConstFlat in1,
39                       typename TTypes<T>::ConstFlat in2) {
40     out.device(d) = in1 + in2;
41   }
42 };
43 
44 template <typename Device, typename T>
45 struct Add3Functor {
46   void operator()(const Device& d, typename TTypes<T>::Flat out,
47                   typename TTypes<T>::ConstFlat in1,
48                   typename TTypes<T>::ConstFlat in2,
49                   typename TTypes<T>::ConstFlat in3);
50 };
51 
52 template <typename Device, typename T>
53 struct Add3EigenImpl {
ComputeAdd3EigenImpl54   static void Compute(const Device& d, typename TTypes<T>::Flat out,
55                       typename TTypes<T>::ConstFlat in1,
56                       typename TTypes<T>::ConstFlat in2,
57                       typename TTypes<T>::ConstFlat in3) {
58     out.device(d) = in1 + in2 + in3;
59   }
60 };
61 
62 template <typename Device, typename T>
63 struct Add4Functor {
64   void operator()(const Device& d, typename TTypes<T>::Flat out,
65                   typename TTypes<T>::ConstFlat in1,
66                   typename TTypes<T>::ConstFlat in2,
67                   typename TTypes<T>::ConstFlat in3,
68                   typename TTypes<T>::ConstFlat in4);
69 };
70 
71 template <typename Device, typename T>
72 struct Add4EigenImpl {
ComputeAdd4EigenImpl73   static void Compute(const Device& d, typename TTypes<T>::Flat out,
74                       typename TTypes<T>::ConstFlat in1,
75                       typename TTypes<T>::ConstFlat in2,
76                       typename TTypes<T>::ConstFlat in3,
77                       typename TTypes<T>::ConstFlat in4) {
78     out.device(d) = in1 + in2 + in3 + in4;
79   }
80 };
81 
82 template <typename Device, typename T>
83 struct Add5Functor {
84   void operator()(const Device& d, typename TTypes<T>::Flat out,
85                   typename TTypes<T>::ConstFlat in1,
86                   typename TTypes<T>::ConstFlat in2,
87                   typename TTypes<T>::ConstFlat in3,
88                   typename TTypes<T>::ConstFlat in4,
89                   typename TTypes<T>::ConstFlat in5);
90 };
91 
92 template <typename Device, typename T>
93 struct Add5EigenImpl {
ComputeAdd5EigenImpl94   static void Compute(const Device& d, typename TTypes<T>::Flat out,
95                       typename TTypes<T>::ConstFlat in1,
96                       typename TTypes<T>::ConstFlat in2,
97                       typename TTypes<T>::ConstFlat in3,
98                       typename TTypes<T>::ConstFlat in4,
99                       typename TTypes<T>::ConstFlat in5) {
100     out.device(d) = in1 + in2 + in3 + in4 + in5;
101   }
102 };
103 
104 template <typename Device, typename T>
105 struct Add6Functor {
106   void operator()(const Device& d, typename TTypes<T>::Flat out,
107                   typename TTypes<T>::ConstFlat in1,
108                   typename TTypes<T>::ConstFlat in2,
109                   typename TTypes<T>::ConstFlat in3,
110                   typename TTypes<T>::ConstFlat in4,
111                   typename TTypes<T>::ConstFlat in5,
112                   typename TTypes<T>::ConstFlat in6);
113 };
114 
115 template <typename Device, typename T>
116 struct Add6EigenImpl {
ComputeAdd6EigenImpl117   static void Compute(const Device& d, typename TTypes<T>::Flat out,
118                       typename TTypes<T>::ConstFlat in1,
119                       typename TTypes<T>::ConstFlat in2,
120                       typename TTypes<T>::ConstFlat in3,
121                       typename TTypes<T>::ConstFlat in4,
122                       typename TTypes<T>::ConstFlat in5,
123                       typename TTypes<T>::ConstFlat in6) {
124     out.device(d) = in1 + in2 + in3 + in4 + in5 + in6;
125   }
126 };
127 
128 template <typename Device, typename T>
129 struct Add7Functor {
130   void operator()(const Device& d, typename TTypes<T>::Flat out,
131                   typename TTypes<T>::ConstFlat in1,
132                   typename TTypes<T>::ConstFlat in2,
133                   typename TTypes<T>::ConstFlat in3,
134                   typename TTypes<T>::ConstFlat in4,
135                   typename TTypes<T>::ConstFlat in5,
136                   typename TTypes<T>::ConstFlat in6,
137                   typename TTypes<T>::ConstFlat in7);
138 };
139 
140 template <typename Device, typename T>
141 struct Add7EigenImpl {
ComputeAdd7EigenImpl142   static void Compute(const Device& d, typename TTypes<T>::Flat out,
143                       typename TTypes<T>::ConstFlat in1,
144                       typename TTypes<T>::ConstFlat in2,
145                       typename TTypes<T>::ConstFlat in3,
146                       typename TTypes<T>::ConstFlat in4,
147                       typename TTypes<T>::ConstFlat in5,
148                       typename TTypes<T>::ConstFlat in6,
149                       typename TTypes<T>::ConstFlat in7) {
150     out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7;
151   }
152 };
153 
154 template <typename Device, typename T>
155 struct Add8Functor {
156   void operator()(
157       const Device& d, typename TTypes<T>::Flat out,
158       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
159       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
160       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
161       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8);
162 };
163 
164 template <typename Device, typename T>
165 struct Add8EigenImpl {
ComputeAdd8EigenImpl166   static void Compute(
167       const Device& d, typename TTypes<T>::Flat out,
168       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
169       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
170       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
171       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
172     out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8;
173   }
174 };
175 
176 // Add8p is like Add8 except the underlying implementation should +=
177 // rather than assign to the output.
178 template <typename Device, typename T>
179 struct Add8pFunctor {
180   void operator()(
181       const Device& d, typename TTypes<T>::Flat out,
182       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
183       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
184       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
185       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8);
186 };
187 
188 template <typename Device, typename T>
189 struct Add8pEigenImpl {
ComputeAdd8pEigenImpl190   static void Compute(
191       const Device& d, typename TTypes<T>::Flat out,
192       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
193       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
194       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
195       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
196     out.device(d) += in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8;
197   }
198 };
199 
200 template <typename Device, typename T>
201 struct Add9Functor {
202   void operator()(
203       const Device& d, typename TTypes<T>::Flat out,
204       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
205       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
206       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
207       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
208       typename TTypes<T>::ConstFlat in9);
209 };
210 
211 template <typename Device, typename T>
212 struct Add9EigenImpl {
ComputeAdd9EigenImpl213   static void Compute(
214       const Device& d, typename TTypes<T>::Flat out,
215       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
216       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
217       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
218       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
219       typename TTypes<T>::ConstFlat in9) {
220     out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8 + in9;
221   }
222 };
223 }  // namespace functor
224 }  // namespace tensorflow
225 
226 #endif  // TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_
227