1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_ 17 #define TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_ 18 19 #include <numeric> 20 21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 22 #include "tensorflow/core/framework/tensor_types.h" 23 24 namespace tensorflow { 25 namespace functor { 26 27 // Functor definitions for Aggregate ops, must be compilable by nvcc. 28 template <typename Device, typename T> 29 struct Add2Functor { 30 void operator()(const Device& d, typename TTypes<T>::Flat out, 31 typename TTypes<T>::ConstFlat in1, 32 typename TTypes<T>::ConstFlat in2); 33 }; 34 35 template <typename Device, typename T> 36 struct Add2EigenImpl { ComputeAdd2EigenImpl37 static void Compute(const Device& d, typename TTypes<T>::Flat out, 38 typename TTypes<T>::ConstFlat in1, 39 typename TTypes<T>::ConstFlat in2) { 40 out.device(d) = in1 + in2; 41 } 42 }; 43 44 template <typename Device, typename T> 45 struct Add3Functor { 46 void operator()(const Device& d, typename TTypes<T>::Flat out, 47 typename TTypes<T>::ConstFlat in1, 48 typename TTypes<T>::ConstFlat in2, 49 typename TTypes<T>::ConstFlat in3); 50 }; 51 52 template <typename Device, typename T> 53 struct Add3EigenImpl { ComputeAdd3EigenImpl54 static void Compute(const Device& d, typename TTypes<T>::Flat out, 55 typename TTypes<T>::ConstFlat in1, 56 typename TTypes<T>::ConstFlat in2, 57 typename TTypes<T>::ConstFlat in3) { 58 out.device(d) = in1 + in2 + in3; 59 } 60 }; 61 62 template <typename Device, typename T> 63 struct Add4Functor { 64 void operator()(const Device& d, typename TTypes<T>::Flat out, 65 typename TTypes<T>::ConstFlat in1, 66 typename TTypes<T>::ConstFlat in2, 67 typename TTypes<T>::ConstFlat in3, 68 typename TTypes<T>::ConstFlat in4); 69 }; 70 71 template <typename Device, typename T> 72 struct Add4EigenImpl { ComputeAdd4EigenImpl73 static void Compute(const Device& d, typename TTypes<T>::Flat out, 74 typename TTypes<T>::ConstFlat in1, 75 typename TTypes<T>::ConstFlat in2, 76 typename TTypes<T>::ConstFlat in3, 77 typename TTypes<T>::ConstFlat in4) { 78 out.device(d) = in1 + in2 + in3 + in4; 79 } 80 }; 81 82 template <typename Device, typename T> 83 struct Add5Functor { 84 void operator()(const Device& d, typename TTypes<T>::Flat out, 85 typename TTypes<T>::ConstFlat in1, 86 typename TTypes<T>::ConstFlat in2, 87 typename TTypes<T>::ConstFlat in3, 88 typename TTypes<T>::ConstFlat in4, 89 typename TTypes<T>::ConstFlat in5); 90 }; 91 92 template <typename Device, typename T> 93 struct Add5EigenImpl { ComputeAdd5EigenImpl94 static void Compute(const Device& d, typename TTypes<T>::Flat out, 95 typename TTypes<T>::ConstFlat in1, 96 typename TTypes<T>::ConstFlat in2, 97 typename TTypes<T>::ConstFlat in3, 98 typename TTypes<T>::ConstFlat in4, 99 typename TTypes<T>::ConstFlat in5) { 100 out.device(d) = in1 + in2 + in3 + in4 + in5; 101 } 102 }; 103 104 template <typename Device, typename T> 105 struct Add6Functor { 106 void operator()(const Device& d, typename TTypes<T>::Flat out, 107 typename TTypes<T>::ConstFlat in1, 108 typename TTypes<T>::ConstFlat in2, 109 typename TTypes<T>::ConstFlat in3, 110 typename TTypes<T>::ConstFlat in4, 111 typename TTypes<T>::ConstFlat in5, 112 typename TTypes<T>::ConstFlat in6); 113 }; 114 115 template <typename Device, typename T> 116 struct Add6EigenImpl { ComputeAdd6EigenImpl117 static void Compute(const Device& d, typename TTypes<T>::Flat out, 118 typename TTypes<T>::ConstFlat in1, 119 typename TTypes<T>::ConstFlat in2, 120 typename TTypes<T>::ConstFlat in3, 121 typename TTypes<T>::ConstFlat in4, 122 typename TTypes<T>::ConstFlat in5, 123 typename TTypes<T>::ConstFlat in6) { 124 out.device(d) = in1 + in2 + in3 + in4 + in5 + in6; 125 } 126 }; 127 128 template <typename Device, typename T> 129 struct Add7Functor { 130 void operator()(const Device& d, typename TTypes<T>::Flat out, 131 typename TTypes<T>::ConstFlat in1, 132 typename TTypes<T>::ConstFlat in2, 133 typename TTypes<T>::ConstFlat in3, 134 typename TTypes<T>::ConstFlat in4, 135 typename TTypes<T>::ConstFlat in5, 136 typename TTypes<T>::ConstFlat in6, 137 typename TTypes<T>::ConstFlat in7); 138 }; 139 140 template <typename Device, typename T> 141 struct Add7EigenImpl { ComputeAdd7EigenImpl142 static void Compute(const Device& d, typename TTypes<T>::Flat out, 143 typename TTypes<T>::ConstFlat in1, 144 typename TTypes<T>::ConstFlat in2, 145 typename TTypes<T>::ConstFlat in3, 146 typename TTypes<T>::ConstFlat in4, 147 typename TTypes<T>::ConstFlat in5, 148 typename TTypes<T>::ConstFlat in6, 149 typename TTypes<T>::ConstFlat in7) { 150 out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7; 151 } 152 }; 153 154 template <typename Device, typename T> 155 struct Add8Functor { 156 void operator()( 157 const Device& d, typename TTypes<T>::Flat out, 158 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 159 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 160 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 161 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8); 162 }; 163 164 template <typename Device, typename T> 165 struct Add8EigenImpl { ComputeAdd8EigenImpl166 static void Compute( 167 const Device& d, typename TTypes<T>::Flat out, 168 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 169 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 170 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 171 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) { 172 out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8; 173 } 174 }; 175 176 // Add8p is like Add8 except the underlying implementation should += 177 // rather than assign to the output. 178 template <typename Device, typename T> 179 struct Add8pFunctor { 180 void operator()( 181 const Device& d, typename TTypes<T>::Flat out, 182 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 183 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 184 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 185 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8); 186 }; 187 188 template <typename Device, typename T> 189 struct Add8pEigenImpl { ComputeAdd8pEigenImpl190 static void Compute( 191 const Device& d, typename TTypes<T>::Flat out, 192 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 193 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 194 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 195 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) { 196 out.device(d) += in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8; 197 } 198 }; 199 200 template <typename Device, typename T> 201 struct Add9Functor { 202 void operator()( 203 const Device& d, typename TTypes<T>::Flat out, 204 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 205 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 206 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 207 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8, 208 typename TTypes<T>::ConstFlat in9); 209 }; 210 211 template <typename Device, typename T> 212 struct Add9EigenImpl { ComputeAdd9EigenImpl213 static void Compute( 214 const Device& d, typename TTypes<T>::Flat out, 215 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 216 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 217 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 218 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8, 219 typename TTypes<T>::ConstFlat in9) { 220 out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8 + in9; 221 } 222 }; 223 } // namespace functor 224 } // namespace tensorflow 225 226 #endif // TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_ 227