#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Reduce.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
#include <ATen/cpu/vec/functional.h>
#include <algorithm>

namespace at::native {
namespace {

// Load a vector of a smaller scalar type (more lanes) into a vector of a larger
// accumulator type (fewer lanes), reducing neighboring elements until the result
// fits the accumulator vector's lane count.
template <typename acc_t, typename scalar_t, typename F>
Vectorized<acc_t> load_reduce_vec(const scalar_t* data, F reduce, acc_t ident) {
  using vec_t = Vectorized<scalar_t>;
  using vacc_t = Vectorized<acc_t>;
  static_assert(vacc_t::size() <= vec_t::size());
  const auto val = vec_t::loadu(data);
  alignas(64) std::array<scalar_t, vec_t::size()> values;
  val.store(values.data());

  constexpr int vstride = vec_t::size() / vacc_t::size();
  alignas(64) std::array<acc_t, vacc_t::size()> acc;
  acc.fill(ident);
  for (const auto k : c10::irange(vstride)) {
    for (const auto i : c10::irange(vacc_t::size())) {
      acc[i] = reduce(acc[i], values[i * vstride + k]);
    }
  }

  return vacc_t::loadu(acc.data());
}
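
// Illustrative sketch (hypothetical lane counts, not part of the kernel): if vec_t
// held 8 scalar_t lanes and vacc_t held 4 acc_t lanes, then vstride == 2 and each
// output lane reduces two neighbouring inputs:
//
//   out[i] = reduce(reduce(ident, data[2*i]), data[2*i + 1])    for i in [0, 4)
//
// e.g. with reduce = [](acc_t a, scalar_t b) { return a + b; } and ident = 0, the
// result is a 4-lane vector of pairwise sums of 8 consecutive input elements.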

template <typename scalar_t>
struct LoadPolicy {
  static constexpr int64_t memsize() {
    return sizeof(scalar_t);
  }

  static scalar_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto *ptr = reinterpret_cast<const scalar_t*>(data + index * stride);
    return *ptr;
  }
};

template <typename scalar_t>
struct LoadPolicy<Vectorized<scalar_t>> {
  static constexpr int64_t memsize() {
    return sizeof(scalar_t) * Vectorized<scalar_t>::size();
  }

  static Vectorized<scalar_t> load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto *ptr = data + index * stride;
    return Vectorized<scalar_t>::loadu(ptr);
  }
};

/* When summing float16 or BFloat16, addition has to be performed in float since
 * that's all the hardware supports. These cast-load policies ensure the entire sum
 * loop is done in float which improves both performance and accuracy.
 */

template <typename scalar_t, typename acc_t>
struct CastLoadPolicy {
  static constexpr int64_t memsize() {
    return sizeof(scalar_t);
  }

  static acc_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    const auto val = LoadPolicy<scalar_t>::load(data, stride, index);
    return acc_t(val);
  }
};

template <typename scalar_t>
struct CastLoadPolicy<scalar_t, scalar_t>:
  LoadPolicy<scalar_t> {
};
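
// Illustrative sketch (assumed types, not part of the kernel): for a BFloat16
// input summed in float, CastLoadPolicy<c10::BFloat16, float>::load reads one
// BFloat16 element and widens it, so a scalar loop accumulates entirely in float:
//
//   const char* bytes = ...;   // points at contiguous BFloat16 elements
//   float acc = 0.f;
//   for (int64_t i = 0; i < n; ++i) {
//     acc += CastLoadPolicy<c10::BFloat16, float>::load(
//         bytes, sizeof(c10::BFloat16), i);
//   }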

// For inner sum, load full vec_t then sum partials down to vacc_t size
template <typename vec_t, typename vacc_t, typename = void>
struct InnerSumCastLoadPolicy;

template <typename vec_t, typename vacc_t>
struct InnerSumCastLoadPolicy <vec_t, vacc_t,
  std::enable_if_t<(!is_reduced_floating_point_v<vechold_type<vec_t>>) &&
                   !std::is_same_v<vec_t, vacc_t>>> {
  using scalar_t = vechold_type<vec_t>;
  using acc_t = vechold_type<vacc_t>;

  static constexpr int64_t memsize() {
    return LoadPolicy<vec_t>::memsize();
  }

  static vacc_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto ptr = reinterpret_cast<const scalar_t*>(data + stride * index);
    return load_reduce_vec<acc_t>(ptr, [](acc_t a, scalar_t b) {
      return a + b;
    }, acc_t(0));
  }
};

template <typename scalar_t>
struct InnerSumCastLoadPolicy<scalar_t, scalar_t, void>:
  LoadPolicy<scalar_t> {
};

template <typename vec_t, typename vacc_t>
struct InnerSumCastLoadPolicy <vec_t, vacc_t, std::enable_if_t<is_reduced_floating_point_v<vechold_type<vec_t>>>> {
  using scalar_t = vechold_type<vec_t>;

  static constexpr int64_t memsize() {
    return LoadPolicy<vec_t>::memsize();
  }

  static vacc_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto ptr = reinterpret_cast<const scalar_t*>(data + stride * index);
    vacc_t first, second;
    vec::load_to_float<scalar_t>(ptr, first, second);
    return first + second;
  }
};
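
// Illustrative sketch (lane counts assume 256-bit vectors): for BFloat16 each call
// consumes a full Vectorized<BFloat16> (16 lanes, i.e. memsize() bytes), converts
// its two halves to float and adds them lane-wise:
//
//   bf16[0..15]  --load_to_float-->  f32 first, f32 second
//   result = first + second          // one 8-lane float vector of partial sums
//
// The partial sums are later folded into a single scalar at the end of the row.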

// For outer sum, load a partial vec_t of size vacc_t then cast to vacc_t
template <typename vec_t, typename vacc_t, typename = void>
struct OuterSumCastLoadPolicy;

template <typename vec_t, typename vacc_t>
struct OuterSumCastLoadPolicy <vec_t, vacc_t,
  std::enable_if_t<(!is_reduced_floating_point_v<vechold_type<vec_t>>) &&
                   !std::is_same_v<vec_t, vacc_t>>> {

  using scalar_t = vechold_type<vec_t>;
  using acc_t = vechold_type<vacc_t>;

  static constexpr int64_t memsize() {
    return sizeof(scalar_t) * vacc_t::size();
  }

  static vacc_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    static_assert(vacc_t::size() <= vec_t::size());
    const auto val = vec_t::loadu(data + stride * index, vacc_t::size());
    alignas(64) scalar_t values[vec_t::size()];
    val.store(values);

    alignas(64) acc_t acc[vacc_t::size()];
    for (const auto i : c10::irange(vacc_t::size())) {
      acc[i] = values[i];
    }

    return vacc_t::loadu(acc);
  }
};

template <typename vec_t, typename vacc_t>
struct OuterSumCastLoadPolicy <vec_t, vacc_t, std::enable_if_t<is_reduced_floating_point_v<vechold_type<vec_t>>>> {
  using scalar_t = vechold_type<vec_t>;

  static constexpr int64_t memsize() {
    return sizeof(scalar_t) * vacc_t::size();
  }

  static vacc_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto ptr = reinterpret_cast<const scalar_t*>(data + stride * index);
    vacc_t values;
    vec::load_to_float<scalar_t>(ptr, values);
    return values;
  }
};

template <typename scalar_t>
struct OuterSumCastLoadPolicy<scalar_t, scalar_t, void>:
  LoadPolicy<scalar_t> {
};
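
// Illustrative sketch (lane counts assume 256-bit vectors): unlike the inner-sum
// policies, the outer-sum policies read only vacc_t::size() scalars per call, so
// memsize() reflects a partial vector. With BFloat16 input, each call widens 8
// elements into one 8-lane float vector with no in-register reduction:
//
//   call k:  bf16[8*k .. 8*k + 7]  --load_to_float-->  one f32 vector
//
// so each accumulator lane tracks a different (non-reduced) output element.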

/* To implement nansum, augment the load operation to mask out nans before
 * entering the normal sum loop.
 */

template <typename scalar_t>
struct NanSumLoadPolicy {
  static constexpr int64_t memsize() {
    return sizeof(scalar_t);
  }

  static scalar_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto val = LoadPolicy<scalar_t>::load(data, stride, index);
    return at::_isnan(val) ? scalar_t(0) : val;
  }
};

template <typename scalar_t>
struct NanSumLoadPolicy<Vectorized<scalar_t>> {
  using vec_t = Vectorized<scalar_t>;

  static constexpr int64_t memsize() {
    return LoadPolicy<vec_t>::memsize();
  }

  static vec_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto val = LoadPolicy<vec_t>::load(data, stride, index);
    return vec_t::blendv(val, vec_t(0), val.isnan());
  }
};

template <typename scalar_t, typename acc_t>
struct NanSumCastLoadPolicy {
  static constexpr int64_t memsize() {
    return sizeof(scalar_t);
  }

  static acc_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto val = CastLoadPolicy<scalar_t, acc_t>::load(data, stride, index);
    return at::_isnan(val) ? acc_t(0) : val;
  }
};

template <typename vec_t, typename vacc_t, typename = void>
struct InnerNanSumCastLoadPolicy;

template <typename vec_t, typename vacc_t>
struct InnerNanSumCastLoadPolicy <vec_t, vacc_t,
  std::enable_if_t<(!is_reduced_floating_point_v<vechold_type<vec_t>>) &&
                   !std::is_same_v<vec_t, vacc_t>>> {
  using scalar_t = vechold_type<vec_t>;
  using acc_t = vechold_type<vacc_t>;

  static constexpr int64_t memsize() {
    return LoadPolicy<vec_t>::memsize();
  }

  static vacc_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto ptr = reinterpret_cast<const scalar_t*>(data + stride * index);
    return load_reduce_vec<acc_t>(ptr, [](acc_t a, scalar_t b) {
      return at::_isnan(b) ? a : a + b;
    }, acc_t(0));
  }
};

template <typename scalar_t>
struct InnerNanSumCastLoadPolicy<scalar_t, scalar_t, void>:
  NanSumLoadPolicy<scalar_t> {
};

template <typename vec_t, typename vacc_t>
struct InnerNanSumCastLoadPolicy <vec_t, vacc_t, std::enable_if_t<is_reduced_floating_point_v<vechold_type<vec_t>>>> {
  using scalar_t = vechold_type<vec_t>;

  static constexpr int64_t memsize() {
    return LoadPolicy<vec_t>::memsize();
  }

  static vacc_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto ptr = reinterpret_cast<const scalar_t*>(data + stride * index);
    vacc_t first, second;
    vec::load_to_float<scalar_t>(ptr, first, second);
    const vacc_t zero(0);
    return (vacc_t::blendv(first, zero, first.isnan()) +
            vacc_t::blendv(second, zero, second.isnan()));
  }
};

template <typename vec_t, typename vacc_t>
struct OuterNanSumCastLoadPolicy {
  static constexpr int64_t memsize() {
    return OuterSumCastLoadPolicy<vec_t, vacc_t>::memsize();
  }

  static vacc_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
    auto val = OuterSumCastLoadPolicy<vec_t, vacc_t>::load(data, stride, index);
    return vacc_t::blendv(val, vacc_t(0), val.isnan());
  }
};
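
// Illustrative sketch: all the nansum policies share the same contract - a NaN
// element contributes 0 instead of poisoning the sum. For the scalar policy:
//
//   float data[3] = {1.0f, NAN, 2.0f};
//   const char* p = reinterpret_cast<const char*>(data);
//   NanSumLoadPolicy<float>::load(p, sizeof(float), 0) == 1.0f
//   NanSumLoadPolicy<float>::load(p, sizeof(float), 1) == 0.0f   // NaN masked
//   NanSumLoadPolicy<float>::load(p, sizeof(float), 2) == 2.0f
//
// so summing the loads yields 3.0f, matching torch.nansum semantics.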

template <typename scalar_t, typename acc_t>
struct CastStoreAccumulate {
  static void store(char * C10_RESTRICT data, int64_t stride, int64_t index, acc_t value) {
    auto * ptr = reinterpret_cast<scalar_t*>(data + index * stride);
    *ptr += value;
  }
};

template <typename StorePolicy, typename scalar_t>
static void store(char * C10_RESTRICT data, int64_t stride, int64_t index, scalar_t value) {
  StorePolicy::store(data, stride, index, value);
}

template <typename StorePolicy, typename scalar_t, size_t numel>
static void store(char * C10_RESTRICT data, int64_t stride, int64_t index,
                  const std::array<scalar_t, numel> &values) {
  auto *base_ptr = data + stride * index;
  for (const auto k : c10::irange(numel)) {
    auto val = values[k];
    StorePolicy::store(base_ptr, stride, k, val);
  }
}

template <typename StorePolicy, typename scalar_t>
static void store(char * C10_RESTRICT data, int64_t stride, int64_t index,
                  const Vectorized<scalar_t> &values) {
  using vec_t = Vectorized<scalar_t>;
  alignas(64) std::array<scalar_t, vec_t::size()> array_values{};
  values.store(array_values.data());
  store<StorePolicy>(data, stride, index, array_values);
}
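
// Illustrative sketch (not part of the kernel): the Vectorized overload spills the
// vector to a stack array and forwards each lane through the scalar StorePolicy,
// so a vector of partial sums stored at output index j accumulates into outputs
// j, j+1, ..., j + lanes - 1, each via CastStoreAccumulate's "*ptr += value":
//
//   alignas(64) float out[Vectorized<float>::size()] = {};   // hypothetical outputs
//   store<CastStoreAccumulate<float, float>>(
//       reinterpret_cast<char*>(out), sizeof(float), 0, Vectorized<float>(1.0f));
//   // every element of `out` now holds 1.0f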

/** Simultaneously sum over n rows at once

  This algorithm calculates the sum without loss of precision over large axes. It
  does this by chunking the sum into groups of 16 or more elements. The sums of
  these chunks are also summed in chunks and so on until there is just a single sum
  value remaining. This means only numbers of a similar order of magnitude are
  added together, thus minimising rounding errors.

  This is done in a single linear pass over the data and with O(1) extra storage.
  A simplified recursive implementation would look like this:

    scalar_t row_sum(const scalar_t * data, int64_t n) {
      // Note, in practice the chunk size can increase with n
      // This allows the recursion depth to be limited to O(1).
      constexpr int64_t min_chunk_size = 16;

      scalar_t sum = 0;
      if (n <= min_chunk_size) {
        // Recursive base case, calculate a simple running sum
        for (const auto i : c10::irange(n)) {
          sum += data[i];
        }
        return sum;
      }

      // Recursively sum larger chunks of elements
      const int64_t chunk_size = std::max(divup(n, min_chunk_size), min_chunk_size);
      for (int64_t i = 0; i < n; i += chunk_size) {
        sum += row_sum(data + i, std::min(chunk_size, n - i));
      }
      return sum;
    }
*/
template <typename scalar_t, int64_t nrows, typename LoadPolicy>
std::array<scalar_t, nrows> multi_row_sum(
    const char * C10_RESTRICT in_data,
    const int64_t row_stride,
    const int64_t col_stride,
    const int64_t size) {
  constexpr int64_t num_levels = 4;

  const int64_t level_power =
      std::max(int64_t(4), utils::CeilLog2(size) / num_levels);
  const int64_t level_step = (1 << level_power);
  const int64_t level_mask = level_step - 1;

  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  scalar_t acc[num_levels][nrows];
  std::fill_n(&acc[0][0], num_levels * nrows, scalar_t(0));

  int64_t i = 0;
  for (; i + level_step <= size;) {
    for (int64_t j = 0; j < level_step; ++j, ++i) {
      const char * sum_base = in_data + i * row_stride;
      #if !defined(COMPILING_FOR_MIN_SIZE)
      # pragma unroll
      #endif
      for (const auto k : c10::irange(nrows)) {
        acc[0][k] += LoadPolicy::load(sum_base, col_stride, k);
      }
    }

    for (const auto j : c10::irange(1, num_levels)) {
      #if !defined(COMPILING_FOR_MIN_SIZE)
      # pragma unroll
      #endif
      for (const auto k : c10::irange(nrows)) {
        acc[j][k] += acc[j-1][k];
        acc[j-1][k] = scalar_t(0);
      }

      const auto mask = (level_mask << (j * level_power));
      if ((i & mask) != 0) {
        break;
      }
    }
  }

  for (; i < size; ++i) {
    const char * sum_base = in_data + i * row_stride;
    #if !defined(COMPILING_FOR_MIN_SIZE)
    # pragma unroll
    #endif
    for (const auto k : c10::irange(nrows)) {
      acc[0][k] += LoadPolicy::load(sum_base, col_stride, k);
    }
  }

  for (const auto j : c10::irange(1, num_levels)) {
    #if !defined(COMPILING_FOR_MIN_SIZE)
    # pragma unroll
    #endif
    for (const auto k : c10::irange(nrows)) {
      acc[0][k] += acc[j][k];
    }
  }

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  std::array<scalar_t, nrows> ret;
  for (const auto k : c10::irange(nrows)) {
    ret[k] = acc[0][k];
  }
  return ret;
}
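
// Worked example of the cascade levels (sizes illustrative): for size == 65536,
// CeilLog2(size) == 16, so level_power == 4 and level_step == 16. acc[0] then
// accumulates at most 16 loads before being flushed into acc[1]; acc[1] is flushed
// into acc[2] only every 16 * 16 = 256 elements, and so on:
//
//   acc[0]: <= 2^4  elements    acc[1]: <= 2^8  elements
//   acc[2]: <= 2^12 elements    acc[3]: the remainder
//
// so each addition combines numbers of roughly similar magnitude.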

template <typename scalar_t, typename LoadPolicy>
scalar_t row_sum(const char * C10_RESTRICT in_data,
                 const int64_t in_stride, const int64_t size) {
  constexpr int64_t ilp_factor = 4;

  // Interpret row as a (-1, ilp_factor) shaped array to find partial sums
  const int64_t size_ilp = size / ilp_factor;
  auto partial_sums = multi_row_sum<scalar_t, ilp_factor, LoadPolicy>(
      in_data, in_stride * ilp_factor, in_stride, size_ilp);

  for (int64_t i = size_ilp * ilp_factor; i < size; ++i) {
    partial_sums[0] += LoadPolicy::load(in_data, in_stride, i);
  }

  for (const auto k : c10::irange(1, ilp_factor)) {
    partial_sums[0] += partial_sums[k];
  }

  return partial_sums[0];
}
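
// Illustrative sketch: row_sum keeps ilp_factor == 4 independent partial sums so
// the additions can overlap (instruction-level parallelism). E.g. for size == 10:
//
//   partial_sums[k] = data[k] + data[k + 4]      (k = 0..3, via multi_row_sum)
//   tail:             data[8] and data[9] added into partial_sums[0]
//   result:           partial_sums[0] + partial_sums[1] + partial_sums[2] + partial_sums[3]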

template <typename acc_t, typename VecLoadPolicy, typename ScalarLoadPolicy, typename StorePolicy>
void vectorized_inner_sum(
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    char * C10_RESTRICT data[2], int64_t outer_stride, int64_t out_stride,
    int64_t size0, int64_t size1) {
  using vacc_t = Vectorized<acc_t>;
  constexpr int64_t vec_stride = VecLoadPolicy::memsize();
  constexpr int64_t scalar_stride = ScalarLoadPolicy::memsize();
  constexpr int64_t vec_numel = vec_stride / scalar_stride;
  const int64_t vec_size = size0 / vec_numel;

  // Input is contiguous over the first (reduced) dimension
  for (const auto j : c10::irange(size1)) {
    const auto *row_in = data[1] + j * outer_stride;
    auto vec_acc = row_sum<vacc_t, VecLoadPolicy>(row_in, vec_stride, vec_size);

    acc_t final_acc = 0;
    for (int64_t k = vec_size * vec_numel; k < size0; ++k) {
      final_acc += ScalarLoadPolicy::load(row_in, scalar_stride, k);
    }

    alignas(64) std::array<acc_t, vacc_t::size()> partials{};
    vec_acc.store(partials.data());
    for (const auto k : c10::irange(partials.size())) {
      final_acc += partials[k];
    }
    store<StorePolicy>(data[0], out_stride, j, final_acc);
  }
}

template <typename acc_t, typename LoadPolicy, typename StorePolicy>
void scalar_inner_sum(
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    char * C10_RESTRICT data[2], int64_t in_strides[2], int64_t out_stride,
    int64_t size0, int64_t size1) {
  for (const auto j : c10::irange(size1)) {
    const auto *row_in = data[1] + j * in_strides[1];
    auto ans = row_sum<acc_t, LoadPolicy>(row_in, in_strides[0], size0);
    store<StorePolicy>(data[0], out_stride, j, ans);
  }
}

template <typename acc_t, typename VecLoadPolicy, typename ScalarLoadPolicy, typename StorePolicy>
void vectorized_outer_sum(
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    char * C10_RESTRICT data[2], int64_t inner_stride, int64_t out_stride,
    int64_t size0, int64_t size1) {
  using vacc_t = Vectorized<acc_t>;
  constexpr int64_t scalar_stride = ScalarLoadPolicy::memsize();
  constexpr int64_t vec_stride = VecLoadPolicy::memsize();
  constexpr int64_t nrows = 4;

  // Input is contiguous over the second (non-reduced) dimension
  int64_t j = 0;
  for (; j + nrows * vacc_t::size() <= size1; j += nrows * vacc_t::size()) {
    const auto *row_in = data[1] + j * scalar_stride;
    auto sums = multi_row_sum<vacc_t, nrows, VecLoadPolicy>(
        row_in, inner_stride, vec_stride, size0);

    for (const auto i : c10::irange(nrows)) {
      const int64_t base_idx = j + i * vacc_t::size();
      store<StorePolicy>(data[0], out_stride, base_idx, sums[i]);
    }
  }

  for (; j + vacc_t::size() <= size1; j += vacc_t::size()) {
    const auto *row_in = data[1] + j * scalar_stride;
    const vacc_t sums = row_sum<vacc_t, VecLoadPolicy>(
        row_in, inner_stride, size0);

    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
    store<StorePolicy>(data[0], out_stride, j, sums);
  }

  for (; j < size1; ++j) {
    const auto *row_in = data[1] + j * scalar_stride;
    auto ans = row_sum<acc_t, ScalarLoadPolicy>(row_in, inner_stride, size0);
    store<StorePolicy>(data[0], out_stride, j, ans);
  }
}
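
// Illustrative sketch of the two vectorized layouts, for a reduction computing
// out[j] = sum over k of the inputs in row/column j:
//
//   inner sum: the reduced dimension k is contiguous, so each SIMD load covers
//              several k values and the lanes are folded into one scalar at the
//              end of each row (horizontal reduction).
//   outer sum: the non-reduced dimension j is contiguous, so each SIMD lane owns
//              a different output j while the loop walks k; no horizontal
//              reduction is needed and whole vectors are stored to consecutive
//              outputs.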

template <typename acc_t, typename LoadPolicy, typename StorePolicy>
void scalar_outer_sum(
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    char * C10_RESTRICT data[2], int64_t in_strides[2], int64_t out_stride,
    int64_t size0, int64_t size1) {
  constexpr int64_t nrows = 4;
  int64_t j = 0;
  for (; j + (nrows - 1) < size1; j += nrows) {
    const auto *row_in = data[1] + j * in_strides[1];
    auto sums = multi_row_sum<acc_t, nrows, LoadPolicy>(
        row_in, in_strides[0], in_strides[1], size0);
    store<StorePolicy>(data[0], out_stride, j, sums);
  }

  for (; j < size1; ++j) {
    const auto *row_in = data[1] + j * in_strides[1];
    auto ans = row_sum<acc_t, LoadPolicy>(
        row_in, in_strides[0], size0);
    store<StorePolicy>(data[0], out_stride, j, ans);
  }
}

// Custom floating point sum for better accuracy
template <bool ignore_nan, typename scalar_t>
void cascade_sum(TensorIterator &iter) {
  iter.output_base().fill_(scalar_t(0));
  iter.parallel_reduce(
    [&](char** data, const int64_t* strides, int64_t size0, int64_t size1) {
      // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
      int64_t in_strides[] = { strides[1], strides[3] };
      // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
      int64_t out_strides[] = { strides[0], strides[2] };

      // Move reduction to be the 1st dim
      if (out_strides[0] != 0 && out_strides[1] == 0) {
        std::swap(in_strides[0], in_strides[1]);
        std::swap(out_strides[0], out_strides[1]);
        std::swap(size0, size1);
      }

      // Special case: not a true reduction (neither output stride is zero)
      if (out_strides[0] != 0 && out_strides[1] != 0) {
        // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
        int64_t outer_strides[] = { strides[2], strides[3] };
        UNARY_OUTER_LOOP(data, outer_strides, size1, [&] {
          // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
          char* ptrs[3] = { data[0], data[0], data[1] };
          // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
          int64_t inner_strides[3] = { strides[0], strides[0], strides[1] };
          if constexpr (ignore_nan) {
            basic_loop(ptrs, inner_strides, 0, size0, [](scalar_t a, scalar_t b) {
              auto a_notnan = at::_isnan(a) ? scalar_t(0) : a;
              auto b_notnan = at::_isnan(b) ? scalar_t(0) : b;
              return a_notnan + b_notnan;
            });
          } else {
            basic_loop(ptrs, inner_strides, 0, size0,
                       [](scalar_t a, scalar_t b) { return a + b; });
          }
        });
        return;
      }

      const int64_t out_stride = out_strides[1];
      TORCH_INTERNAL_ASSERT(out_strides[0] == 0);

      using vec_t = Vectorized<scalar_t>;
      using acc_t = at::acc_type<scalar_t, true>;
      using vacc_t = Vectorized<acc_t>;
      using ScalarLoadPolicy = std::conditional_t<
          ignore_nan,
          NanSumCastLoadPolicy<scalar_t, acc_t>,
          CastLoadPolicy<scalar_t, acc_t>>;
      using StorePolicy = CastStoreAccumulate<scalar_t, acc_t>;

      if (in_strides[0] == sizeof(scalar_t) && size0 >= vec_t::size()) {
        // Contiguous inner reduction
        using VecLoadPolicy = std::conditional_t<
            ignore_nan,
            InnerNanSumCastLoadPolicy<vec_t, vacc_t>,
            InnerSumCastLoadPolicy<vec_t, vacc_t>>;
        vectorized_inner_sum<acc_t, VecLoadPolicy, ScalarLoadPolicy, StorePolicy>(
            data, in_strides[1], out_stride, size0, size1);
      } else if (in_strides[1] == sizeof(scalar_t) && size1 >= vec_t::size()) {
        // Contiguous outer reduction
        using VecLoadPolicy = std::conditional_t<
            ignore_nan,
            OuterNanSumCastLoadPolicy<vec_t, vacc_t>,
            OuterSumCastLoadPolicy<vec_t, vacc_t>>;
        vectorized_outer_sum<acc_t, VecLoadPolicy, ScalarLoadPolicy, StorePolicy>(
            data, in_strides[0], out_stride, size0, size1);
      } else if (in_strides[0] < in_strides[1]) {
        scalar_inner_sum<acc_t, ScalarLoadPolicy, StorePolicy>(
            data, in_strides, out_stride, size0, size1);
      } else {
        scalar_outer_sum<acc_t, ScalarLoadPolicy, StorePolicy>(
            data, in_strides, out_stride, size0, size1);
      }
    });
}
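
// Illustrative summary of the dispatch above (shapes and strides hypothetical):
// summing a contiguous float32 tensor of shape (1000, 4) over dim 1 makes the
// reduced dimension contiguous, but its extent (4) is smaller than the vector
// width, so the scalar_inner_sum path would run; summing the same tensor over
// dim 0 makes the non-reduced dimension contiguous with extent 4, so
// scalar_outer_sum would run instead. Larger contiguous extents take the
// corresponding vectorized_* paths.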

void sum_kernel_impl(TensorIterator &iter) {
  if (isIntegralType(iter.dtype(), /*includeBool=*/ true)) {
    AT_DISPATCH_INTEGRAL_TYPES_AND(ScalarType::Bool, iter.dtype(), "sum_cpu",
      [&] {
        binary_kernel_reduce_vec(
            iter, [=](scalar_t a, scalar_t b) -> scalar_t { return a + b; },
            [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a + b; });
      });
    return;
  }

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
    ScalarType::BFloat16, ScalarType::Half, iter.dtype(), "sum_cpu", [&] {
      cascade_sum</*ignore_nan=*/false, scalar_t>(iter);
    });
}

void nansum_kernel_impl(TensorIterator &iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
    ScalarType::BFloat16, ScalarType::Half, iter.dtype(), "nansum_cpu", [&] {
      cascade_sum</*ignore_nan=*/true, scalar_t>(iter);
    });
}

} // namespace (anonymous)

// nansum on Float16 has poor accuracy with AVX2, and more so with AVX512.
// So until it's fixed, it won't be dispatched with AVX512. GH issue 59415.
// Besides, these kernels are slower with AVX512 than with AVX2.
REGISTER_DISPATCH(nansum_stub, &nansum_kernel_impl);
REGISTER_DISPATCH(sum_stub, &sum_kernel_impl);

} // namespace at::native