1 #pragma once 2 3 // DO NOT DEFINE STATIC DATA IN THIS HEADER! 4 // See Note [Do not compile initializers with AVX] 5 6 #include <ATen/cpu/vec/vec.h> 7 8 namespace at::vec { 9 10 // BFloat16 specification 11 template <typename scalar_t> struct VecScalarType { using type = scalar_t; }; 12 template <> struct VecScalarType<BFloat16> { using type = float; }; 13 template <> struct VecScalarType<Half> { using type = float; }; 14 15 // This is different from at::acc_type since we only need to specialize BFloat16 16 template <typename scalar_t> 17 using vec_scalar_t = typename VecScalarType<scalar_t>::type; 18 19 // Vector conversion between float and bfloat16/half 20 template <typename scalar_t, 21 typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0> 22 inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float(const Vectorized<scalar_t>&); 23 24 template <> 25 inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<BFloat16> (const Vectorized<BFloat16>& a) { 26 return convert_bfloat16_float(a); 27 } 28 29 template <> 30 inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<Half> (const Vectorized<Half>& a) { 31 return convert_half_float(a); 32 } 33 34 template <typename scalar_t, 35 typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0> 36 inline Vectorized<scalar_t> convert_from_float(const Vectorized<float>&, const Vectorized<float>&); 37 38 template <> 39 inline Vectorized<BFloat16> convert_from_float<BFloat16>(const Vectorized<float>& a, const Vectorized<float>& b) { 40 return convert_float_bfloat16(a, b); 41 } 42 43 template <> 44 inline Vectorized<Half> convert_from_float<Half>(const Vectorized<float>& a, const Vectorized<float>& b) { 45 return convert_float_half(a, b); 46 } 47 48 template <typename scalar_t, 49 typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0> 50 inline void load_to_float(const scalar_t *data, Vectorized<float> &out1, Vectorized<float> &out2); 51 52 template <> 53 inline void load_to_float<BFloat16> (const BFloat16 *data, Vectorized<float> &out1, Vectorized<float> &out2) { 54 load_fp32_from_bf16(data, out1, out2); 55 } 56 57 template <> 58 inline void load_to_float<Half> (const Half *data, Vectorized<float> &out1, Vectorized<float> &out2) { 59 load_fp32_from_fp16(data, out1, out2); 60 } 61 62 template <typename scalar_t, 63 typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0> 64 inline void load_to_float(const scalar_t *data, Vectorized<float> &out); 65 66 template <> 67 inline void load_to_float<BFloat16> (const BFloat16 *data, Vectorized<float> &out) { 68 load_fp32_from_bf16(data, out); 69 } 70 71 template <> 72 inline void load_to_float<Half> (const Half *data, Vectorized<float> &out) { 73 load_fp32_from_fp16(data, out); 74 } 75 76 // Note that we already have specialized member of Vectorized<scalar_t> for BFloat16 77 // so the following functions would run smoothly: 78 // using Vec = Vectorized<BFloat16>; 79 // Vec one = Vec(BFloat16(1)); 80 // vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N); 81 // 82 // Then why we still need to specialize "functional"? 83 // If we do specialization at Vectorized<> level, the above example would need 3 pairs of 84 // conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and "/". 85 // If we do specialization at vec::map<>() level, we have only 1 pair of conversion 86 // of bf16->fp32/fp32->bf16, for the input and output BFloat16 vector only. 87 // 88 // The following BFloat16 functionality will only do data type conversion for input 89 // and output vector (reduce functionality will only convert the final scalar back to bf16). 90 // Compared to Vectorized<> specialization, 91 // 1. better performance since we have less data type conversion; 92 // 2. less rounding error since immediate results are kept in fp32; 93 // 3. accumulation done on data type of fp32. 94 // 95 // If you plan to extend this file, please ensure adding unit tests at 96 // aten/src/ATen/test/vec_test_all_types.cpp 97 // 98 template <typename scalar_t, typename Op, 99 typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0> 100 inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) { 101 using bVec = vec::Vectorized<scalar_t>; 102 using fVec = vec::Vectorized<float>; 103 if (size < bVec::size()) { 104 bVec data_bvec = bVec::loadu(data, size); 105 auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec); 106 if (size > fVec::size()) { 107 data_fvec0 = fVec::set(data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size()); 108 return vec_reduce_all<float>(vec_fun, data_fvec0, fVec::size()); 109 } else { 110 return vec_reduce_all<float>(vec_fun, data_fvec0, size); 111 } 112 } 113 int64_t d = bVec::size(); 114 bVec acc_bvec = bVec::loadu(data); 115 auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec); 116 for (; d < size - (size % bVec::size()); d += bVec::size()) { 117 bVec data_bvec = bVec::loadu(data + d); 118 auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec); 119 acc_fvec0 = vec_fun(acc_fvec0, data_fvec0); 120 acc_fvec1 = vec_fun(acc_fvec1, data_fvec1); 121 } 122 if (size - d > 0) { 123 bVec data_bvec = bVec::loadu(data + d, size - d); 124 auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec); 125 if (size - d > fVec::size()) { 126 acc_fvec0 = vec_fun(acc_fvec0, data_fvec0); 127 acc_fvec1 = fVec::set(acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); 128 } else { 129 acc_fvec0 = fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d); 130 } 131 } 132 acc_fvec0 = vec_fun(acc_fvec0, acc_fvec1); 133 return vec_reduce_all<float>(vec_fun, acc_fvec0); 134 } 135 136 template <typename scalar_t, typename Op1, typename Op2, 137 typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0> 138 inline std::pair<float, float> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2, 139 const scalar_t* data, int64_t size) { 140 using bVec = vec::Vectorized<scalar_t>; 141 using fVec = vec::Vectorized<float>; 142 if (size < bVec::size()) { 143 bVec data_bvec = bVec::loadu(data, size); 144 auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec); 145 if (size > fVec::size()) { 146 fVec acc1_fvec = fVec::set(data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size()); 147 fVec acc2_fvec = fVec::set(data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size()); 148 return std::pair<scalar_t, scalar_t>( 149 vec_reduce_all<float>(vec_fun1, acc1_fvec, fVec::size()), 150 vec_reduce_all<float>(vec_fun2, acc2_fvec, fVec::size())); 151 } else { 152 return std::pair<scalar_t, scalar_t>( 153 vec_reduce_all<float>(vec_fun1, data_fvec0, size), 154 vec_reduce_all<float>(vec_fun2, data_fvec0, size)); 155 } 156 } 157 int64_t d = bVec::size(); 158 bVec acc_bvec = bVec::loadu(data); 159 auto [acc1_fvec0, acc1_fvec1] = convert_to_float<scalar_t>(acc_bvec); 160 auto [acc2_fvec0, acc2_fvec1] = convert_to_float<scalar_t>(acc_bvec); 161 for (; d < size - (size % bVec::size()); d += bVec::size()) { 162 bVec data_bvec = bVec::loadu(data + d); 163 auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec); 164 acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0); 165 acc1_fvec1 = vec_fun1(acc1_fvec1, data_fvec1); 166 acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0); 167 acc2_fvec1 = vec_fun2(acc2_fvec1, data_fvec1); 168 } 169 if (size - d > 0) { 170 bVec data_bvec = bVec::loadu(data + d, size - d); 171 auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec); 172 if (size - d > fVec::size()) { 173 acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0); 174 acc1_fvec1 = fVec::set(acc1_fvec1, vec_fun1(acc1_fvec1, data_fvec1), size - d - fVec::size()); 175 acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0); 176 acc2_fvec1 = fVec::set(acc2_fvec1, vec_fun2(acc2_fvec1, data_fvec1), size - d - fVec::size()); 177 } else { 178 acc1_fvec0 = fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d); 179 acc2_fvec0 = fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d); 180 } 181 } 182 acc1_fvec0 = vec_fun1(acc1_fvec0, acc1_fvec1); 183 acc2_fvec0 = vec_fun2(acc2_fvec0, acc2_fvec1); 184 return std::pair<scalar_t, scalar_t>( 185 vec_reduce_all<float>(vec_fun1, acc1_fvec0), 186 vec_reduce_all<float>(vec_fun2, acc2_fvec0)); 187 } 188 189 template <typename scalar_t, typename MapOp, typename ReduceOp, 190 typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0> 191 inline float map_reduce_all( 192 const MapOp& map_fun, 193 const ReduceOp& red_fun, 194 const scalar_t* data, 195 int64_t size) { 196 using bVec = vec::Vectorized<scalar_t>; 197 using fVec = vec::Vectorized<float>; 198 if (size < bVec::size()) { 199 bVec data_bvec = bVec::loadu(data, size); 200 auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec); 201 if (size > fVec::size()) { 202 data_fvec0 = map_fun(data_fvec0); 203 data_fvec1 = map_fun(data_fvec1); 204 data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); 205 return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size()); 206 } else { 207 data_fvec0 = map_fun(data_fvec0); 208 return vec_reduce_all<float>(red_fun, data_fvec0, size); 209 } 210 } 211 int64_t d = bVec::size(); 212 bVec acc_bvec = bVec::loadu(data); 213 auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec); 214 acc_fvec0 = map_fun(acc_fvec0); 215 acc_fvec1 = map_fun(acc_fvec1); 216 for (; d < size - (size % bVec::size()); d += bVec::size()) { 217 bVec data_bvec = bVec::loadu(data + d); 218 auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec); 219 data_fvec0 = map_fun(data_fvec0); 220 data_fvec1 = map_fun(data_fvec1); 221 acc_fvec0 = red_fun(acc_fvec0, data_fvec0); 222 acc_fvec1 = red_fun(acc_fvec1, data_fvec1); 223 } 224 if (size - d > 0) { 225 bVec data_bvec = bVec::loadu(data + d, size - d); 226 auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec); 227 if (size - d > fVec::size()) { 228 data_fvec0 = map_fun(data_fvec0); 229 data_fvec1 = map_fun(data_fvec1); 230 acc_fvec0 = red_fun(acc_fvec0, data_fvec0); 231 acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); 232 } else { 233 data_fvec0 = map_fun(data_fvec0); 234 acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); 235 } 236 } 237 acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); 238 return vec_reduce_all<float>(red_fun, acc_fvec0); 239 } 240 241 template <typename scalar_t, typename MapOp, typename ReduceOp, 242 typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0> 243 inline float map2_reduce_all( 244 const MapOp& map_fun, 245 const ReduceOp& red_fun, 246 const scalar_t* data, 247 const scalar_t* data2, 248 int64_t size) { 249 using bVec = vec::Vectorized<scalar_t>; 250 using fVec = vec::Vectorized<float>; 251 if (size < bVec::size()) { 252 bVec data_bvec = bVec::loadu(data, size); 253 auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec); 254 bVec data2_bvec = bVec::loadu(data2, size); 255 auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec); 256 if (size > fVec::size()) { 257 data_fvec0 = map_fun(data_fvec0, data2_fvec0); 258 data_fvec1 = map_fun(data_fvec1, data2_fvec1); 259 data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); 260 return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size()); 261 } else { 262 data_fvec0 = map_fun(data_fvec0, data2_fvec0); 263 return vec_reduce_all<float>(red_fun, data_fvec0, size); 264 } 265 } 266 int64_t d = bVec::size(); 267 bVec acc_bvec = bVec::loadu(data); 268 auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec); 269 bVec acc2_bvec = bVec::loadu(data2); 270 auto [acc2_fvec0, acc2_fvec1] = convert_to_float<scalar_t>(acc2_bvec); 271 acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0); 272 acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1); 273 for (; d < size - (size % bVec::size()); d += bVec::size()) { 274 bVec data_bvec = bVec::loadu(data + d); 275 auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec); 276 bVec data2_bvec = bVec::loadu(data2 + d); 277 auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec); 278 data_fvec0 = map_fun(data_fvec0, data2_fvec0); 279 data_fvec1 = map_fun(data_fvec1, data2_fvec1); 280 acc_fvec0 = red_fun(acc_fvec0, data_fvec0); 281 acc_fvec1 = red_fun(acc_fvec1, data_fvec1); 282 } 283 if (size - d > 0) { 284 bVec data_bvec = bVec::loadu(data + d, size - d); 285 auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec); 286 bVec data2_bvec = bVec::loadu(data2 + d, size - d); 287 auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec); 288 if (size - d > fVec::size()) { 289 data_fvec0 = map_fun(data_fvec0, data2_fvec0); 290 data_fvec1 = map_fun(data_fvec1, data2_fvec1); 291 acc_fvec0 = red_fun(acc_fvec0, data_fvec0); 292 acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); 293 } else { 294 data_fvec0 = map_fun(data_fvec0, data2_fvec0); 295 acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); 296 } 297 } 298 acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); 299 return vec_reduce_all<float>(red_fun, acc_fvec0); 300 } 301 302 template <typename scalar_t, typename MapOp, typename ReduceOp, 303 typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0> 304 inline float map3_reduce_all( 305 const MapOp& map_fun, 306 const ReduceOp& red_fun, 307 const scalar_t* data, 308 const scalar_t* data2, 309 const scalar_t* data3, 310 int64_t size) { 311 using bVec = vec::Vectorized<scalar_t>; 312 using fVec = vec::Vectorized<float>; 313 if (size < bVec::size()) { 314 bVec data_bvec = bVec::loadu(data, size); 315 auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec); 316 bVec data2_bvec = bVec::loadu(data2, size); 317 auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec); 318 bVec data3_bvec = bVec::loadu(data3, size); 319 auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec); 320 if (size > fVec::size()) { 321 data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); 322 data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1); 323 data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); 324 return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size()); 325 } else { 326 data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); 327 return vec_reduce_all<float>(red_fun, data_fvec0, size); 328 } 329 } 330 int64_t d = bVec::size(); 331 bVec acc_bvec = bVec::loadu(data); 332 auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec); 333 bVec acc2_bvec = bVec::loadu(data2); 334 auto [acc2_fvec0, acc2_fvec1] = convert_to_float<scalar_t>(acc2_bvec); 335 bVec acc3_bvec = bVec::loadu(data3); 336 auto [acc3_fvec0, acc3_fvec1] = convert_to_float<scalar_t>(acc3_bvec); 337 acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0, acc3_fvec0); 338 acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1, acc3_fvec1); 339 for (; d < size - (size % bVec::size()); d += bVec::size()) { 340 bVec data_bvec = bVec::loadu(data + d); 341 auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec); 342 bVec data2_bvec = bVec::loadu(data2 + d); 343 auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec); 344 bVec data3_bvec = bVec::loadu(data3 + d); 345 auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec); 346 data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); 347 data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1); 348 acc_fvec0 = red_fun(acc_fvec0, data_fvec0); 349 acc_fvec1 = red_fun(acc_fvec1, data_fvec1); 350 } 351 if (size - d > 0) { 352 bVec data_bvec = bVec::loadu(data + d, size - d); 353 auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec); 354 bVec data2_bvec = bVec::loadu(data2 + d, size - d); 355 auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec); 356 bVec data3_bvec = bVec::loadu(data3 + d, size - d); 357 auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec); 358 if (size - d > fVec::size()) { 359 data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); 360 data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1); 361 acc_fvec0 = red_fun(acc_fvec0, data_fvec0); 362 acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); 363 } else { 364 data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); 365 acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); 366 } 367 } 368 acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); 369 return vec_reduce_all<float>(red_fun, acc_fvec0); 370 } 371 372 template <typename scalar_t, typename Op, 373 typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0> 374 inline void map( 375 const Op& vec_fun, 376 scalar_t* output_data, 377 const scalar_t* input_data, 378 int64_t size) { 379 using bVec = vec::Vectorized<scalar_t>; 380 using fVec = vec::Vectorized<float>; 381 int64_t d = 0; 382 for (; d < size - (size % bVec::size()); d += bVec::size()) { 383 bVec data_bvec = bVec::loadu(input_data + d); 384 auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec); 385 fVec output_fvec0 = vec_fun(data_fvec0); 386 fVec output_fvec1 = vec_fun(data_fvec1); 387 bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1); 388 output_bvec.store(output_data + d); 389 } 390 if (size - d > 0) { 391 bVec data_bvec = bVec::loadu(input_data + d, size - d); 392 auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec); 393 fVec output_fvec0 = vec_fun(data_fvec0); 394 fVec output_fvec1 = vec_fun(data_fvec1); 395 bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1); 396 output_bvec.store(output_data + d, size - d); 397 } 398 } 399 400 template <typename scalar_t, typename Op, 401 typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0> 402 inline void map( 403 const Op& vec_fun, 404 scalar_t* output_data, 405 const float* input_data, 406 int64_t size) { 407 using bVec = vec::Vectorized<scalar_t>; 408 using fVec = vec::Vectorized<float>; 409 int64_t d = 0; 410 for (; d < size - (size % bVec::size()); d += bVec::size()) { 411 fVec data_fvec0 = fVec::loadu(input_data + d); 412 fVec data_fvec1 = fVec::loadu(input_data + d + fVec::size()); 413 fVec output_fvec0 = vec_fun(data_fvec0); 414 fVec output_fvec1 = vec_fun(data_fvec1); 415 bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1); 416 output_bvec.store(output_data + d); 417 } 418 if (size - d > 0) { 419 fVec data_fvec0, data_fvec1; 420 if (size - d > fVec::size()) { 421 data_fvec0 = fVec::loadu(input_data + d); 422 data_fvec1 = fVec::loadu(input_data + d + fVec::size(), size - d - fVec::size()); 423 } else { 424 // choose to align with behaviour of bVec::loadu(ptr, size), 425 // which leaves data_fvec1 uninitialized 426 data_fvec0 = fVec::loadu(input_data + d, size - d); 427 } 428 fVec output_fvec0 = vec_fun(data_fvec0); 429 fVec output_fvec1 = vec_fun(data_fvec1); 430 bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1); 431 output_bvec.store(output_data + d, size - d); 432 } 433 } 434 435 template <typename scalar_t, typename Op, 436 typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0> 437 inline void map2( 438 const Op& vec_fun, 439 scalar_t* output_data, 440 const scalar_t* input_data, 441 const scalar_t* input_data2, 442 int64_t size) { 443 using bVec = vec::Vectorized<scalar_t>; 444 using fVec = vec::Vectorized<float>; 445 int64_t d = 0; 446 for (; d < size - (size % bVec::size()); d += bVec::size()) { 447 bVec data_bvec = bVec::loadu(input_data + d); 448 auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec); 449 bVec data2_bvec = bVec::loadu(input_data2 + d); 450 auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec); 451 fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0); 452 fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1); 453 bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1); 454 output_bvec.store(output_data + d); 455 } 456 if (size - d > 0) { 457 bVec data_bvec = bVec::loadu(input_data + d, size - d); 458 auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec); 459 bVec data2_bvec = bVec::loadu(input_data2 + d, size - d); 460 auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec); 461 fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0); 462 fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1); 463 bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1); 464 output_bvec.store(output_data + d, size - d); 465 } 466 } 467 468 template <typename scalar_t, typename Op, 469 typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0> 470 inline void map3( 471 const Op& vec_fun, 472 scalar_t* output_data, 473 const scalar_t* input_data1, 474 const scalar_t* input_data2, 475 const scalar_t* input_data3, 476 int64_t size) { 477 using bVec = vec::Vectorized<scalar_t>; 478 using fVec = vec::Vectorized<float>; 479 int64_t d = 0; 480 for (; d < size - (size % bVec::size()); d += bVec::size()) { 481 bVec data1_bvec = bVec::loadu(input_data1 + d); 482 auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec); 483 bVec data2_bvec = bVec::loadu(input_data2 + d); 484 auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec); 485 bVec data3_bvec = bVec::loadu(input_data3 + d); 486 auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec); 487 fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0); 488 fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1); 489 bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1); 490 output_bvec.store(output_data + d); 491 } 492 if (size - d > 0) { 493 bVec data1_bvec = bVec::loadu(input_data1 + d, size - d); 494 auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec); 495 bVec data2_bvec = bVec::loadu(input_data2 + d, size - d); 496 auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec); 497 bVec data3_bvec = bVec::loadu(input_data3 + d, size - d); 498 auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec); 499 fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0); 500 fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1); 501 bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1); 502 output_bvec.store(output_data + d, size - d); 503 } 504 } 505 506 template <typename scalar_t, typename Op, 507 typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0> 508 inline void map4( 509 const Op& vec_fun, 510 scalar_t* output_data, 511 const scalar_t* input_data1, 512 const scalar_t* input_data2, 513 const scalar_t* input_data3, 514 const scalar_t* input_data4, 515 int64_t size) { 516 using bVec = vec::Vectorized<scalar_t>; 517 using fVec = vec::Vectorized<float>; 518 int64_t d = 0; 519 for (; d < size - (size % bVec::size()); d += bVec::size()) { 520 bVec data1_bvec = bVec::loadu(input_data1 + d); 521 auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec); 522 bVec data2_bvec = bVec::loadu(input_data2 + d); 523 auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec); 524 bVec data3_bvec = bVec::loadu(input_data3 + d); 525 auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec); 526 bVec data4_bvec = bVec::loadu(input_data4 + d); 527 auto [data4_fvec0, data4_fvec1] = convert_to_float<scalar_t>(data4_bvec); 528 fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); 529 fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); 530 bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1); 531 output_bvec.store(output_data + d); 532 } 533 if (size - d > 0) { 534 bVec data1_bvec = bVec::loadu(input_data1 + d, size - d); 535 auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec); 536 bVec data2_bvec = bVec::loadu(input_data2 + d, size - d); 537 auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec); 538 bVec data3_bvec = bVec::loadu(input_data3 + d, size - d); 539 auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec); 540 bVec data4_bvec = bVec::loadu(input_data4 + d, size - d); 541 auto [data4_fvec0, data4_fvec1] = convert_to_float<scalar_t>(data4_bvec); 542 fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); 543 fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); 544 bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1); 545 output_bvec.store(output_data + d, size - d); 546 } 547 } 548 549 } // namespace at::vec 550