// aten/src/ATen/native/cpu/SumKernel.cpp
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Reduce.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
#include <ATen/cpu/vec/functional.h>
#include <algorithm>

namespace at::native {
namespace {

// Load vector from a smaller type (more elements) to a larger type (fewer elements),
// reducing neighboring elements until it fits into the vector size.
template <typename acc_t, typename scalar_t, typename F>
Vectorized<acc_t> load_reduce_vec(const scalar_t* data, F reduce, acc_t ident) {
  using vec_t = Vectorized<scalar_t>;
  using vacc_t = Vectorized<acc_t>;
  static_assert(vacc_t::size() <= vec_t::size());
  const auto val = vec_t::loadu(data);
  alignas(64) std::array<scalar_t, vec_t::size()> values;
  val.store(values.data());

  constexpr int vstride = vec_t::size() / vacc_t::size();
  alignas(64) std::array<acc_t, vacc_t::size()> acc;
  acc.fill(ident);
  for (const auto k : c10::irange(vstride)) {
    for (const auto i : c10::irange(vacc_t::size())) {
      acc[i] = reduce(acc[i], values[i * vstride + k]);
    }
  }

  return vacc_t::loadu(acc.data());
}
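
// Illustrative example (assuming an AVX2 build, where Vectorized<BFloat16>
// holds 16 elements and Vectorized<float> holds 8): vstride == 2, so output
// lane i of load_reduce_vec is reduce(reduce(ident, data[2*i]), data[2*i + 1]).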

template <typename scalar_t>
struct LoadPolicy {
  static constexpr int64_t memsize() {
    return sizeof(scalar_t);
  }

  static scalar_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto *ptr = reinterpret_cast<const scalar_t*>(data + index * stride);
    return *ptr;
  }
};

template <typename scalar_t>
struct LoadPolicy<Vectorized<scalar_t>> {
  static constexpr int64_t memsize() {
    return sizeof(scalar_t) * Vectorized<scalar_t>::size();
  }

  static Vectorized<scalar_t> load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto *ptr = data + index * stride;
    return Vectorized<scalar_t>::loadu(ptr);
  }
};

/* When summing float16 or BFloat16, addition has to be performed in float since
 * that's all the hardware supports. These cast-load policies ensure the entire sum
 * loop is done in float which improves both performance and accuracy.
 */
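// For instance, CastLoadPolicy<Half, float>::load reads a single Half element
// and returns it widened to float, so every partial sum below is accumulated
// in float rather than in the reduced-precision type.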

template <typename scalar_t, typename acc_t>
struct CastLoadPolicy {
  static constexpr int64_t memsize() {
    return sizeof(scalar_t);
  }

  static acc_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    const auto val = LoadPolicy<scalar_t>::load(data, stride, index);
    return acc_t(val);
  }
};

template <typename scalar_t>
struct CastLoadPolicy<scalar_t, scalar_t>:
    LoadPolicy<scalar_t> {
};

// For inner sum, load full vec_t then sum partials down to vacc_t size
template <typename vec_t, typename vacc_t, typename = void>
struct InnerSumCastLoadPolicy;

template <typename vec_t, typename vacc_t>
struct InnerSumCastLoadPolicy <vec_t, vacc_t,
  std::enable_if_t<(!is_reduced_floating_point_v<vechold_type<vec_t>>) &&
                    !std::is_same_v<vec_t, vacc_t>>> {
  using scalar_t = vechold_type<vec_t>;
  using acc_t = vechold_type<vacc_t>;

  static constexpr int64_t memsize() {
    return LoadPolicy<vec_t>::memsize();
  }

  static vacc_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto ptr = reinterpret_cast<const scalar_t*>(data + stride * index);
    return load_reduce_vec<acc_t>(ptr, [](acc_t a, scalar_t b) {
      return a + b;
    }, acc_t(0));
  }
};

template <typename scalar_t>
struct InnerSumCastLoadPolicy<scalar_t, scalar_t, void>:
    LoadPolicy<scalar_t> {
};

template <typename vec_t, typename vacc_t>
struct InnerSumCastLoadPolicy <vec_t, vacc_t, std::enable_if_t<is_reduced_floating_point_v<vechold_type<vec_t>>>> {
  using scalar_t = vechold_type<vec_t>;

  static constexpr int64_t memsize() {
    return LoadPolicy<vec_t>::memsize();
  }

  static vacc_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto ptr = reinterpret_cast<const scalar_t*>(data + stride * index);
    vacc_t first, second;
    vec::load_to_float<scalar_t>(ptr, first, second);
    return first + second;
  }
};

// For outer sum, load a partial vec_t of size vacc_t then cast to vacc_t
template <typename vec_t, typename vacc_t, typename = void>
struct OuterSumCastLoadPolicy;

template <typename vec_t, typename vacc_t>
struct OuterSumCastLoadPolicy <vec_t, vacc_t,
  std::enable_if_t<(!is_reduced_floating_point_v<vechold_type<vec_t>>) &&
                    !std::is_same_v<vec_t, vacc_t>>> {

  using scalar_t = vechold_type<vec_t>;
  using acc_t = vechold_type<vacc_t>;

  static constexpr int64_t memsize() {
    return sizeof(scalar_t) * vacc_t::size();
  }

  static vacc_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    static_assert(vacc_t::size() <= vec_t::size());
    const auto val = vec_t::loadu(data + stride * index, vacc_t::size());
    alignas(64) scalar_t values[vec_t::size()];
    val.store(values);

    alignas(64) acc_t acc[vacc_t::size()];
    for (const auto i : c10::irange(vacc_t::size())) {
      acc[i] = values[i];
    }

    return vacc_t::loadu(acc);
  }
};

template <typename vec_t, typename vacc_t>
struct OuterSumCastLoadPolicy <vec_t, vacc_t, std::enable_if_t<is_reduced_floating_point_v<vechold_type<vec_t>>>> {
  using scalar_t = vechold_type<vec_t>;

  static constexpr int64_t memsize() {
    return sizeof(scalar_t) * vacc_t::size();
  }

  static vacc_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto ptr = reinterpret_cast<const scalar_t*>(data + stride * index);
    vacc_t values;
    vec::load_to_float<scalar_t>(ptr, values);
    return values;
  }
};

template <typename scalar_t>
struct OuterSumCastLoadPolicy<scalar_t, scalar_t, void>:
    LoadPolicy<scalar_t> {
};

/* To implement nansum, augment the load operation to mask out nans before
 * entering the normal sum loop.
 */
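// As a small illustration, NanSumLoadPolicy<float>::load returns 0 in place of
// a NaN element, so nansum over {1.0f, NaN, 2.0f} accumulates to 3.0f.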

template <typename scalar_t>
struct NanSumLoadPolicy {
  static constexpr int64_t memsize() {
    return sizeof(scalar_t);
  }

  static scalar_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto val = LoadPolicy<scalar_t>::load(data, stride, index);
    return at::_isnan(val) ? scalar_t(0) : val;
  }
};

template <typename scalar_t>
struct NanSumLoadPolicy<Vectorized<scalar_t>> {
  using vec_t = Vectorized<scalar_t>;

  static constexpr int64_t memsize() {
    return LoadPolicy<vec_t>::memsize();
  }

  static vec_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto val = LoadPolicy<vec_t>::load(data, stride, index);
    return vec_t::blendv(val, vec_t(0), val.isnan());
  }
};

template <typename scalar_t, typename acc_t>
struct NanSumCastLoadPolicy {
  static constexpr int64_t memsize() {
    return sizeof(scalar_t);
  }

  static acc_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto val = CastLoadPolicy<scalar_t, acc_t>::load(data, stride, index);
    return at::_isnan(val) ? acc_t(0) : val;
  }
};

template <typename vec_t, typename vacc_t, typename = void>
struct InnerNanSumCastLoadPolicy;

template <typename vec_t, typename vacc_t>
struct InnerNanSumCastLoadPolicy <vec_t, vacc_t,
  std::enable_if_t<(!is_reduced_floating_point_v<vechold_type<vec_t>>) &&
                    !std::is_same_v<vec_t, vacc_t>>> {
  using scalar_t = vechold_type<vec_t>;
  using acc_t = vechold_type<vacc_t>;

  static constexpr int64_t memsize() {
    return LoadPolicy<vec_t>::memsize();
  }

  static vacc_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto ptr = reinterpret_cast<const scalar_t*>(data + stride * index);
    return load_reduce_vec<acc_t>(ptr, [](acc_t a, scalar_t b) {
      return at::_isnan(b) ? a : a + b;
    }, acc_t(0));
  }
};

template <typename scalar_t>
struct InnerNanSumCastLoadPolicy<scalar_t, scalar_t, void>:
    NanSumLoadPolicy<scalar_t> {
};

template <typename vec_t, typename vacc_t>
struct InnerNanSumCastLoadPolicy <vec_t, vacc_t, std::enable_if_t<is_reduced_floating_point_v<vechold_type<vec_t>>>> {
  using scalar_t = vechold_type<vec_t>;

  static constexpr int64_t memsize() {
    return LoadPolicy<vec_t>::memsize();
  }

  static vacc_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto ptr = reinterpret_cast<const scalar_t*>(data + stride * index);
    vacc_t first, second;
    vec::load_to_float<scalar_t>(ptr, first, second);
    const vacc_t zero(0);
    return (vacc_t::blendv(first, zero, first.isnan()) +
            vacc_t::blendv(second, zero, second.isnan()));
  }
};

template <typename vec_t, typename vacc_t>
struct OuterNanSumCastLoadPolicy {
  static constexpr int64_t memsize() {
    return OuterSumCastLoadPolicy<vec_t, vacc_t>::memsize();
  }

  static vacc_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto val = OuterSumCastLoadPolicy<vec_t, vacc_t>::load(data, stride, index);
    return vacc_t::blendv(val, vacc_t(0), val.isnan());
  }
};

template <typename scalar_t, typename acc_t>
struct CastStoreAccumulate {
  static void store(char * C10_RESTRICT data, int64_t stride, int64_t index, acc_t value) {
    auto * ptr = reinterpret_cast<scalar_t*>(data + index * stride);
    *ptr += value;
  }
};

template <typename StorePolicy, typename scalar_t>
static void store(char * C10_RESTRICT data, int64_t stride, int64_t index, scalar_t value) {
  StorePolicy::store(data, stride, index, value);
}

template <typename StorePolicy, typename scalar_t, size_t numel>
static void store(char * C10_RESTRICT data, int64_t stride, int64_t index,
                  const std::array<scalar_t, numel> &values) {
  auto *base_ptr = data + stride * index;
  for (const auto k : c10::irange(numel)) {
    auto val = values[k];
    StorePolicy::store(base_ptr, stride, k, val);
  }
}

template <typename StorePolicy, typename scalar_t>
static void store(char * C10_RESTRICT data, int64_t stride, int64_t index,
                  const Vectorized<scalar_t> &values) {
  using vec_t = Vectorized<scalar_t>;
  alignas(64) std::array<scalar_t, vec_t::size()> array_values{};
  values.store(array_values.data());
  store<StorePolicy>(data, stride, index, array_values);
}

/** Simultaneously sum over n rows at once

This algorithm calculates the sum without loss of precision over large axes. It
does this by chunking the sum into groups of 16 or more elements. The sums of
these chunks are also summed in chunks and so on until there is just a single sum
value remaining. This means only numbers of a similar order of magnitude are
added together, thus minimising rounding errors.

This is done in a single linear pass over the data and with O(1) extra storage.
A simplified recursive implementation would look like this:

  scalar_t row_sum(const scalar_t * data, int64_t n) {
    // Note, in practice the chunk size can increase with n
    // This allows the recursion depth to be limited to O(1).
    constexpr int64_t min_chunk_size = 16;

    scalar_t sum = 0;
    if (n <= min_chunk_size) {
      // Recursive base case, calculate a simple running sum
      for (const auto i : c10::irange(n)) {
        sum += data[i];
      }
      return sum;
    }

    // Recursively sum larger chunks of elements
    const int64_t chunk_size = std::max(divup(n, min_chunk_size), min_chunk_size);
    for (int64_t i = 0; i < n; i += chunk_size) {
      sum += row_sum(data + i, std::min(chunk_size, n - i));
    }
    return sum;
  }
*/
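// The iterative implementation below replaces the recursion with num_levels
// cascaded accumulators per row (a sketch of how it behaves, not a contract):
// acc[0] accumulates level_step consecutive loads and is then flushed into
// acc[1] and reset; acc[1] is flushed into acc[2] once every level_step
// flushes of acc[0], and so on, much like carries propagating through a
// base-level_step counter over the element index i. Each level therefore only
// ever adds partial sums of a similar magnitude.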
template <typename scalar_t, int64_t nrows, typename LoadPolicy>
std::array<scalar_t, nrows> multi_row_sum(
    const char * C10_RESTRICT in_data,
    const int64_t row_stride,
    const int64_t col_stride,
    const int64_t size) {
  constexpr int64_t num_levels = 4;

  const int64_t level_power =
      std::max(int64_t(4), utils::CeilLog2(size) / num_levels);
  const int64_t level_step = (1 << level_power);
  const int64_t level_mask = level_step - 1;

  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  scalar_t acc[num_levels][nrows];
  std::fill_n(&acc[0][0], num_levels * nrows, scalar_t(0));

  int64_t i = 0;
  for (; i + level_step <= size;) {
    for (int64_t j = 0; j < level_step; ++j, ++i) {
      const char * sum_base = in_data + i * row_stride;
      #if !defined(COMPILING_FOR_MIN_SIZE)
      # pragma unroll
      #endif
      for (const auto k : c10::irange(nrows)) {
        acc[0][k] += LoadPolicy::load(sum_base, col_stride, k);
      }
    }

    for (const auto j : c10::irange(1, num_levels)) {
      #if !defined(COMPILING_FOR_MIN_SIZE)
      # pragma unroll
      #endif
      for (const auto k : c10::irange(nrows)) {
        acc[j][k] += acc[j-1][k];
        acc[j-1][k] = scalar_t(0);
      }

      const auto mask = (level_mask << (j * level_power));
      if ((i & mask) != 0) {
        break;
      }
    }
  }

  for (; i < size; ++i) {
    const char * sum_base = in_data + i * row_stride;
    #if !defined(COMPILING_FOR_MIN_SIZE)
    # pragma unroll
    #endif
    for (const auto k : c10::irange(nrows)) {
      acc[0][k] += LoadPolicy::load(sum_base, col_stride, k);
    }
  }

  for (const auto j : c10::irange(1, num_levels)) {
    #if !defined(COMPILING_FOR_MIN_SIZE)
    # pragma unroll
    #endif
    for (const auto k : c10::irange(nrows)) {
      acc[0][k] += acc[j][k];
    }
  }

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  std::array<scalar_t, nrows> ret;
  for (const auto k : c10::irange(nrows)) {
    ret[k] = acc[0][k];
  }
  return ret;
}

template <typename scalar_t, typename LoadPolicy>
scalar_t row_sum(const char * C10_RESTRICT in_data,
                 const int64_t in_stride, const int64_t size) {
  constexpr int64_t ilp_factor = 4;

  // Interpret row as a (-1, ilp_factor) shaped array to find partial sums
  const int64_t size_ilp = size / ilp_factor;
  auto partial_sums = multi_row_sum<scalar_t, ilp_factor, LoadPolicy>(
      in_data, in_stride * ilp_factor, in_stride, size_ilp);

  for (int64_t i = size_ilp * ilp_factor; i < size; ++i) {
    partial_sums[0] += LoadPolicy::load(in_data, in_stride, i);
  }

  for (const auto k : c10::irange(1, ilp_factor)) {
    partial_sums[0] += partial_sums[k];
  }

  return partial_sums[0];
}

template <typename acc_t, typename VecLoadPolicy, typename ScalarLoadPolicy, typename StorePolicy>
void vectorized_inner_sum(
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    char * C10_RESTRICT data[2], int64_t outer_stride, int64_t out_stride,
    int64_t size0, int64_t size1) {
  using vacc_t = Vectorized<acc_t>;
  constexpr int64_t vec_stride = VecLoadPolicy::memsize();
  constexpr int64_t scalar_stride = ScalarLoadPolicy::memsize();
  constexpr int64_t vec_numel = vec_stride / scalar_stride;
  const int64_t vec_size = size0 / vec_numel;

  // Input is contiguous over the first (reduced) dimension
  for (const auto j : c10::irange(size1)) {
    const auto *row_in = data[1] + j * outer_stride;
    auto vec_acc = row_sum<vacc_t, VecLoadPolicy>(row_in, vec_stride, vec_size);

    acc_t final_acc = 0;
    for (int64_t k = vec_size * vec_numel; k < size0; ++k) {
      final_acc += ScalarLoadPolicy::load(row_in, scalar_stride, k);
    }

    alignas(64) std::array<acc_t, vacc_t::size()> partials{};
    vec_acc.store(partials.data());
    for (const auto k : c10::irange(partials.size())) {
      final_acc += partials[k];
    }
    store<StorePolicy>(data[0], out_stride, j, final_acc);
  }
}

template <typename acc_t, typename LoadPolicy, typename StorePolicy>
void scalar_inner_sum(
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    char * C10_RESTRICT data[2], int64_t in_strides[2], int64_t out_stride,
    int64_t size0, int64_t size1) {
  for (const auto j : c10::irange(size1)) {
    const auto *row_in = data[1] + j * in_strides[1];
    auto ans = row_sum<acc_t, LoadPolicy>(row_in, in_strides[0], size0);
    store<StorePolicy>(data[0], out_stride, j, ans);
  }
}

template <typename acc_t, typename VecLoadPolicy, typename ScalarLoadPolicy, typename StorePolicy>
void vectorized_outer_sum(
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    char * C10_RESTRICT data[2], int64_t inner_stride, int64_t out_stride,
    int64_t size0, int64_t size1) {
  using vacc_t = Vectorized<acc_t>;
  constexpr int64_t scalar_stride = ScalarLoadPolicy::memsize();
  constexpr int64_t vec_stride = VecLoadPolicy::memsize();
  constexpr int64_t nrows = 4;

  // Input is contiguous over the second (non-reduced) dimension
  int64_t j = 0;
  for (; j + nrows * vacc_t::size() <= size1; j += nrows * vacc_t::size()) {
    const auto *row_in = data[1] + j * scalar_stride;
    auto sums = multi_row_sum<vacc_t, nrows, VecLoadPolicy>(
        row_in, inner_stride, vec_stride, size0);

    for (const auto i : c10::irange(nrows)) {
      const int64_t base_idx = j + i * vacc_t::size();
      store<StorePolicy>(data[0], out_stride, base_idx, sums[i]);
    }
  }

  for (; j + vacc_t::size() <= size1; j += vacc_t::size()) {
    const auto *row_in = data[1] + j * scalar_stride;
    const vacc_t sums = row_sum<vacc_t, VecLoadPolicy>(
        row_in, inner_stride, size0);

    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
    store<StorePolicy>(data[0], out_stride, j, sums);
  }

  for (; j < size1; ++j) {
    const auto *row_in = data[1] + j * scalar_stride;
    auto ans = row_sum<acc_t, ScalarLoadPolicy>(row_in, inner_stride, size0);
    store<StorePolicy>(data[0], out_stride, j, ans);
  }
}

template <typename acc_t, typename LoadPolicy, typename StorePolicy>
void scalar_outer_sum(
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    char * C10_RESTRICT data[2], int64_t in_strides[2], int64_t out_stride,
    int64_t size0, int64_t size1) {
  constexpr int64_t nrows = 4;
  int64_t j = 0;
  for (; j + (nrows - 1) < size1; j += nrows) {
    const auto *row_in = data[1] + j * in_strides[1];
    auto sums = multi_row_sum<acc_t, nrows, LoadPolicy>(
        row_in, in_strides[0], in_strides[1], size0);
    store<StorePolicy>(data[0], out_stride, j, sums);
  }

  for (; j < size1; ++j) {
    const auto *row_in = data[1] + j * in_strides[1];
    auto ans = row_sum<acc_t, LoadPolicy>(
        row_in, in_strides[0], size0);
    store<StorePolicy>(data[0], out_stride, j, ans);
  }
}

// Custom floating point sum for better accuracy
template <bool ignore_nan, typename scalar_t>
void cascade_sum(TensorIterator &iter) {
  iter.output_base().fill_(scalar_t(0));
  iter.parallel_reduce(
    [&](char** data, const int64_t* strides, int64_t size0, int64_t size1) {
      // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
      int64_t in_strides[] = { strides[1], strides[3] };
      // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
      int64_t out_strides[] = { strides[0], strides[2] };
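      // Layout note (descriptive, based on how the 2-D loop is handed to this
      // lambda): strides are grouped by dimension and ordered output-first, so
      // strides[0]/strides[2] are the output strides for dims 0 and 1, and
      // strides[1]/strides[3] are the input's.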

      // Move reduction to be the 1st dim
      if (out_strides[0] != 0 && out_strides[1] == 0) {
        std::swap(in_strides[0], in_strides[1]);
        std::swap(out_strides[0], out_strides[1]);
        std::swap(size0, size1);
      }

      // Special case? - not a true reduction
      if (out_strides[0] != 0 && out_strides[1] != 0) {
        // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
        int64_t outer_strides[] = { strides[2], strides[3] };
        UNARY_OUTER_LOOP(data, outer_strides, size1, [&] {
          // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
          char* ptrs[3] = { data[0], data[0], data[1] };
          // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
          int64_t inner_strides[3] = { strides[0], strides[0], strides[1] };
          if constexpr (ignore_nan) {
              basic_loop(ptrs, inner_strides, 0, size0, [](scalar_t a, scalar_t b) {
                auto a_notnan = at::_isnan(a) ? scalar_t(0) : a;
                auto b_notnan = at::_isnan(b) ? scalar_t(0) : b;
                return a_notnan + b_notnan;
              });
          } else {
              basic_loop(ptrs, inner_strides, 0, size0,
                         [](scalar_t a, scalar_t b) { return a + b; });
          }
        });
        return;
      }

      const int64_t out_stride = out_strides[1];
      TORCH_INTERNAL_ASSERT(out_strides[0] == 0);

      using vec_t = Vectorized<scalar_t>;
      using acc_t = at::acc_type<scalar_t, true>;
      using vacc_t = Vectorized<acc_t>;
      using ScalarLoadPolicy = std::conditional_t<
          ignore_nan,
          NanSumCastLoadPolicy<scalar_t, acc_t>,
          CastLoadPolicy<scalar_t, acc_t>>;
      using StorePolicy = CastStoreAccumulate<scalar_t, acc_t>;

      if (in_strides[0] == sizeof(scalar_t) && size0 >= vec_t::size()) {
        // Contiguous inner reduction
        using VecLoadPolicy = std::conditional_t<
            ignore_nan,
            InnerNanSumCastLoadPolicy<vec_t, vacc_t>,
            InnerSumCastLoadPolicy<vec_t, vacc_t>>;
        vectorized_inner_sum<acc_t, VecLoadPolicy, ScalarLoadPolicy, StorePolicy>(
            data, in_strides[1], out_stride, size0, size1);
      } else if (in_strides[1] == sizeof(scalar_t) && size1 >= vec_t::size()) {
        // Contiguous outer reduction
        using VecLoadPolicy = std::conditional_t<
            ignore_nan,
            OuterNanSumCastLoadPolicy<vec_t, vacc_t>,
            OuterSumCastLoadPolicy<vec_t, vacc_t>>;
        vectorized_outer_sum<acc_t, VecLoadPolicy, ScalarLoadPolicy, StorePolicy>(
            data, in_strides[0], out_stride, size0, size1);
      } else if (in_strides[0] < in_strides[1]) {
        scalar_inner_sum<acc_t, ScalarLoadPolicy, StorePolicy>(
            data, in_strides, out_stride, size0, size1);
      } else {
        scalar_outer_sum<acc_t, ScalarLoadPolicy, StorePolicy>(
            data, in_strides, out_stride, size0, size1);
      }
    });
}

void sum_kernel_impl(TensorIterator &iter) {
  if (isIntegralType(iter.dtype(), /*includeBool=*/ true)) {
    AT_DISPATCH_INTEGRAL_TYPES_AND(ScalarType::Bool, iter.dtype(), "sum_cpu",
      [&] {
        binary_kernel_reduce_vec(
            iter, [=](scalar_t a, scalar_t b) -> scalar_t { return a + b; },
            [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a + b; });
      });
    return;
  }

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      ScalarType::BFloat16, ScalarType::Half, iter.dtype(), "sum_cpu", [&] {
    cascade_sum</*ignore_nan=*/false, scalar_t>(iter);
  });
}

void nansum_kernel_impl(TensorIterator &iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16, ScalarType::Half, iter.dtype(), "nansum_cpu", [&] {
    cascade_sum</*ignore_nan=*/true, scalar_t>(iter);
  });
}

}  // namespace (anonymous)

// nansum on Float16 has poor accuracy with AVX2, and more so with AVX512.
// So until it's fixed, it won't be dispatched with AVX512. GH issue 59415.
// Besides, these kernels are slower with AVX512 than with AVX2.
REGISTER_DISPATCH(nansum_stub, &nansum_kernel_impl);
REGISTER_DISPATCH(sum_stub, &sum_kernel_impl);

}  // namespace at::native