// /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/NonEmptyUtils.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/Parallel.h>
#include <ATen/native/cpu/ReduceUtils.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>
#ifdef USE_FBGEMM
#include <fbgemm/Utils.h>
#endif
#include <ATen/OpMathType.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros.h>
#endif
namespace at::native {

namespace {

// Implement as functors since lambdas don't get optimized.
class ReduceMultiply {
public:
  template <typename scalar_t>
  constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
    using opmath_t = at::opmath_type<scalar_t>;
    *self_data *= opmath_t(*src_data);
  }

  constexpr void operator() (bool * self_data, bool * src_data) const {
    *self_data = *self_data && *src_data;
  }
};
static ReduceMultiply reduce_multiply;

class ReduceAdd {
public:
  template <typename scalar_t>
  constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
    using opmath_t = at::opmath_type<scalar_t>;
    *self_data += opmath_t(*src_data);
  }
};
static ReduceAdd reduce_add;

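// Note: ReduceMean only accumulates the sum here, exactly like ReduceAdd; the
// normalization (division by the number of contributions) is expected to be
// applied by the caller after the scatter kernel has run.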
class ReduceMean {
public:
  template <typename scalar_t>
  constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
    using opmath_t = at::opmath_type<scalar_t>;
    *self_data += opmath_t(*src_data);
  }
};
static ReduceMean reduce_mean;

class ReduceMaximum {
public:
  template <typename scalar_t>
  constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
    using opmath_t = at::opmath_type<scalar_t>;
    *self_data = at::_isnan<scalar_t>(*src_data) ? opmath_t(*src_data) : std::max(*self_data, opmath_t(*src_data));
  }
};
static ReduceMaximum reduce_maximum;

class ReduceMinimum {
public:
  template <typename scalar_t>
  constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
    using opmath_t = at::opmath_type<scalar_t>;
    *self_data = at::_isnan<scalar_t>(*src_data) ? opmath_t(*src_data) : std::min(*self_data, opmath_t(*src_data));
  }
};
static ReduceMinimum reduce_minimum;

class TensorAssign {
public:
  template <typename scalar_t>
  constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
    using opmath_t = at::opmath_type<scalar_t>;
    *self_data = opmath_t(*src_data);
  }
};
static TensorAssign tensor_assign;

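// Inner loop over the reduced dimension. For scatter-like kernels the index
// selects the destination slot in `self` (self[idx] <- f(self[idx], src[i]));
// for gather-like kernels it selects the source slot in `src`
// (self[i] <- f(self[i], src[idx])). `is_scatter_like` picks between the two.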
template <bool is_scatter_like = true>
struct _cpu_scatter_gather_dim_loop {
  template <typename scalar_t, typename func_t>
  void operator()(
    at::opmath_type<scalar_t>* self_data, int64_t self_dim_stride,
    int64_t* index_data, int64_t index_dim_stride,
    scalar_t* src_data, int64_t src_dim_stride,
    int64_t dim, int64_t index_dim_size,
    int64_t index_upper_bound,
    func_t& f
  ) {

    for (const auto i : c10::irange(index_dim_size)) {
      int64_t idx_dim = index_data[i * index_dim_stride];
      // we are not putting idx_dim in the error message because it disables
      // loop optimization in clang-7
      TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound,
        "index ", index_data[i * index_dim_stride],
        " is out of bounds for dimension ", dim,
        " with size ", index_upper_bound
      );

      f(
        self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride,
        src_data + (is_scatter_like ? i : idx_dim) * src_dim_stride
      );
    }
  }

  template <typename scalar_t, typename func_t>
  void operator()(
    at::opmath_type<scalar_t>* self_data, int64_t self_dim_stride,
    int64_t* index_data, int64_t index_dim_stride,
    Scalar value,
    int64_t dim, int64_t index_dim_size,
    int64_t index_upper_bound,
    func_t& f
  ) {

    for (const auto i : c10::irange(index_dim_size)) {
      int64_t idx_dim = index_data[i * index_dim_stride];
      // we are not putting idx_dim in the error message because it disables
      // loop optimization in clang-7
      TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound,
        "index ", index_data[i * index_dim_stride],
        " is out of bounds for dimension ", dim,
        " with size ", index_upper_bound
      );
      auto temp = value.to<scalar_t>();
      f(
        self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride, &temp
      );
    }
  }
};

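// For reduced-precision floating point types (Half/BFloat16) the reduction is
// accumulated in a higher-precision (opmath) buffer; the result is copied back
// into `self` once the kernel has finished.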
inline void create_acc_buffer(Tensor& buffer, const Tensor& self, bool need_acc) {
  if (need_acc) {
    auto acc_type = at::toOpMathType(self.scalar_type());
    buffer = at::empty(self.sizes(), self.options().dtype(acc_type));
    buffer.copy_(self);
  } else {
    buffer = self;
  }
}

template <bool is_scatter_like = true>
struct cpu_scatter_gather_base_kernel {
  template <typename func_t>
  void operator()(const Tensor& self, int64_t dim,
    const Tensor& index, const Scalar& value,
    const std::string& method_name, func_t& kernel_func) {

    Tensor buffer;
    bool need_acc = isReducedFloatingType(self.scalar_type());
    create_acc_buffer(buffer, self, need_acc);

    auto index_sizes = ensure_nonempty_vec(index.sizes().vec());
    auto index_strides = ensure_nonempty_vec(index.strides().vec());

    // `dim` is traversed in the kernel,
    // that is why index.stride(dim) = 0 and index.size(dim) = 1.
    // Also, index.size(dim) = 1 makes sure that TensorIterator.DimCounter
    // has the following form : (i_1,..., i_{dim-1}, 0, i_{dim+1},...,i_n).
    index_sizes[dim] = 1;
    index_strides[dim] = 0;

    auto iter = TensorIteratorConfig()
      .check_all_same_dtype(false)
      .resize_outputs(false)
      // NOLINTNEXTLINE(bugprone-argument-comment)
      .declare_static_shape(index.sizes(), /*squash_dim=*/dim)
      .add_output(buffer)
      .add_const_input(index)
      .build();

    auto self_dim_stride = ensure_nonempty_stride(buffer, dim);
    auto self_dim_size = ensure_nonempty_size(buffer, dim);

    auto index_dim_stride = ensure_nonempty_stride(index, dim);
    auto index_dim_size = ensure_nonempty_size(index, dim);

    auto index_upper_bound = self_dim_size;

    // since the index dimension is squashed, the grain size needs to be
    // adjusted accordingly to keep the parallel granularity comparable.
    int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / index_dim_size);

    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, self.scalar_type(),
      "scatter_gather_scalar_cpu", [&] {
        constexpr auto SELF_ITER_STRIDE_IDX = 0;
        constexpr auto INDEX_ITER_STRIDE_IDX = 1;
        using opmath_t = at::opmath_type<scalar_t>;
        _cpu_scatter_gather_dim_loop<is_scatter_like> loop_func;
        auto loop = [&](char** data, const int64_t* strides, int64_t n) {
          auto* self_data_bytes = data[SELF_ITER_STRIDE_IDX];
          auto* index_data_bytes = data[INDEX_ITER_STRIDE_IDX];
          // depending on whether dim is the last dimension, we either run the
          // dim loop inside the TensorIterator loop or the TensorIterator loop
          // inside the dim loop
          if (dim == buffer.dim() - 1) {
            for (const auto nelem C10_UNUSED : c10::irange(n)) {
              // dim loop is a separate code block
              // for better performance
              loop_func.template operator()<scalar_t, func_t>(
                (opmath_t*)self_data_bytes, self_dim_stride,
                (int64_t*)index_data_bytes, index_dim_stride,
                value, dim, index_dim_size, index_upper_bound,
                kernel_func);

              self_data_bytes += strides[SELF_ITER_STRIDE_IDX];
              index_data_bytes += strides[INDEX_ITER_STRIDE_IDX];
            }
          }
          else {
            for (const auto i : c10::irange(index_dim_size)) {
              auto* self_data = self_data_bytes;
              auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride);
              for (const auto nelem C10_UNUSED : c10::irange(n)) {
                int64_t idx_dim = *(int64_t*)index_data;
                // we are not putting idx_dim in the error message because it disables
                // loop optimization in clang-7
                TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound,
                            "index ", *(int64_t*)index_data,
                            " is out of bounds for dimension ", dim,
                            " with size ", index_upper_bound);

                auto temp = value.to<scalar_t>();
                kernel_func((opmath_t*)self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride, &temp);

                self_data += strides[SELF_ITER_STRIDE_IDX];
                index_data += strides[INDEX_ITER_STRIDE_IDX];
              }
            }
          }
        };
        iter.for_each(loop, grain_size);
      }
    );
    if (need_acc) {
      self.copy_(buffer);
    }
  }

  template <typename func_t>
  void operator()(const Tensor& self, int64_t dim,
    const Tensor& index, const Tensor& src,
    const std::string& method_name, func_t& kernel_func) {

    Tensor buffer;
    bool need_acc = isReducedFloatingType(self.scalar_type());
    create_acc_buffer(buffer, self, need_acc);

    auto iter = TensorIteratorConfig()
      .check_all_same_dtype(false)
      .resize_outputs(false)
      // NOLINTNEXTLINE(bugprone-argument-comment)
      .declare_static_shape(index.sizes(), /*squash_dim=*/dim)
      .add_output(buffer)
      .add_const_input(src)
      .add_const_input(index)
      .build();

    auto self_dim_stride = ensure_nonempty_stride(buffer, dim);
    auto self_dim_size = ensure_nonempty_size(buffer, dim);

    auto index_dim_stride = ensure_nonempty_stride(index, dim);
    auto index_dim_size = ensure_nonempty_size(index, dim);

    auto src_dim_stride = ensure_nonempty_stride(src, dim);
    auto src_dim_size = ensure_nonempty_size(src, dim);

    auto index_upper_bound = is_scatter_like ? self_dim_size : src_dim_size;

    int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / index_dim_size);

    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, iter.dtype(1),
      "scatter_gather_tensor_cpu", [&] {
        constexpr auto SELF_ITER_STRIDE_IDX = 0;
        constexpr auto INDEX_ITER_STRIDE_IDX = 2;
        constexpr auto SRC_ITER_STRIDE_IDX = 1;
        using opmath_t = at::opmath_type<scalar_t>;
        _cpu_scatter_gather_dim_loop<is_scatter_like> loop_func;
        auto loop = [&](char** data, const int64_t* strides, int64_t n) {
          auto* self_data_bytes = data[SELF_ITER_STRIDE_IDX];
          auto* index_data_bytes = data[INDEX_ITER_STRIDE_IDX];
          auto* src_data_bytes = data[SRC_ITER_STRIDE_IDX];
          // depending on whether dim is the last dimension, we either run the
          // dim loop inside the TensorIterator loop or the TensorIterator loop
          // inside the dim loop
          if (dim == buffer.dim() - 1) {
            for (const auto nelem C10_UNUSED : c10::irange(n)) {
              // dim loop is a separate code block
              // for better performance
              loop_func.template operator()<scalar_t, func_t>(
                 (opmath_t*)self_data_bytes, self_dim_stride,
                 (int64_t*)index_data_bytes, index_dim_stride,
                 (scalar_t*)src_data_bytes, src_dim_stride,
                 dim, index_dim_size, index_upper_bound,
                 kernel_func
               );

              self_data_bytes += strides[SELF_ITER_STRIDE_IDX];
              index_data_bytes += strides[INDEX_ITER_STRIDE_IDX];
              src_data_bytes += strides[SRC_ITER_STRIDE_IDX];
            }
          }
          else {
            for (const auto i : c10::irange(index_dim_size)) {
              auto* self_data = self_data_bytes;
              auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride);
              auto* src_data = src_data_bytes;
              for (const auto nelem C10_UNUSED : c10::irange(n)) {
                int64_t idx_dim = *(int64_t*)index_data;
                // we are not putting idx_dim in the error message because it disables
                // loop optimization in clang-7
                TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound,
                            "index ", *(int64_t*)index_data,
                            " is out of bounds for dimension ", dim,
                            " with size ", index_upper_bound);

                kernel_func(
                  (opmath_t*)self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride,
                  (scalar_t*)src_data + (is_scatter_like ? i : idx_dim) * src_dim_stride);

                self_data += strides[SELF_ITER_STRIDE_IDX];
                index_data += strides[INDEX_ITER_STRIDE_IDX];
                src_data += strides[SRC_ITER_STRIDE_IDX];
              }
            }
          }
        };
        iter.for_each(loop, grain_size);
      }
    );
    if (need_acc) {
      self.copy_(buffer);
    }
  }

  void operator()(const Tensor& self, int64_t dim,
    const Tensor& index, const Tensor& src,
    const std::string& method_name, ReduceMean& kernel_func) {

    Tensor buffer;
    bool need_acc = isReducedFloatingType(self.scalar_type());
    create_acc_buffer(buffer, self, need_acc);

    auto iter = TensorIteratorConfig()
      .check_all_same_dtype(false)
      .resize_outputs(false)
      // NOLINTNEXTLINE(bugprone-argument-comment)
      .declare_static_shape(index.sizes(), /*squash_dim=*/dim)
      .add_output(buffer)
      .add_const_input(src)
      .add_const_input(index)
      .build();

    auto self_dim_stride = ensure_nonempty_stride(buffer, dim);
    auto self_dim_size = ensure_nonempty_size(buffer, dim);

    auto index_dim_stride = ensure_nonempty_stride(index, dim);
    auto index_dim_size = ensure_nonempty_size(index, dim);

    auto src_dim_stride = ensure_nonempty_stride(src, dim);
    auto src_dim_size = ensure_nonempty_size(src, dim);

    auto index_upper_bound = is_scatter_like ? self_dim_size : src_dim_size;

    int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / index_dim_size);

    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
      ScalarType::Half, ScalarType::BFloat16, iter.dtype(1),
      "scatter_gather_tensor_cpu_reduce_mean", [&] {
        constexpr auto SELF_ITER_STRIDE_IDX = 0;
        constexpr auto INDEX_ITER_STRIDE_IDX = 2;
        constexpr auto SRC_ITER_STRIDE_IDX = 1;
        using opmath_t = at::opmath_type<scalar_t>;
        _cpu_scatter_gather_dim_loop<is_scatter_like> loop_func;
        auto loop = [&](char** data, const int64_t* strides, int64_t n) {
          auto* self_data_bytes = data[SELF_ITER_STRIDE_IDX];
          auto* index_data_bytes = data[INDEX_ITER_STRIDE_IDX];
          auto* src_data_bytes = data[SRC_ITER_STRIDE_IDX];
          // depending on whether dim is the last dimension, we either run the
          // dim loop inside the TensorIterator loop or the TensorIterator loop
          // inside the dim loop
          if (dim == buffer.dim() - 1) {
            for (const auto nelem C10_UNUSED : c10::irange(n)) {
              // dim loop is a separate code block
              // for better performance
              loop_func.template operator()<scalar_t, ReduceMean>(
                 (opmath_t*)self_data_bytes, self_dim_stride,
                 (int64_t*)index_data_bytes, index_dim_stride,
                 (scalar_t*)src_data_bytes, src_dim_stride,
                 dim, index_dim_size, index_upper_bound,
                 kernel_func
               );

              self_data_bytes += strides[SELF_ITER_STRIDE_IDX];
              index_data_bytes += strides[INDEX_ITER_STRIDE_IDX];
              src_data_bytes += strides[SRC_ITER_STRIDE_IDX];
            }
          }
          else {
            for (const auto i : c10::irange(index_dim_size)) {
              auto* self_data = self_data_bytes;
              auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride);
              auto* src_data = src_data_bytes;
              for (const auto nelem C10_UNUSED : c10::irange(n)) {
                int64_t idx_dim = *(int64_t*)index_data;
                // we are not putting idx_dim in the error message because it disables
                // loop optimization in clang-7
                TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound,
                            "index ", *(int64_t*)index_data,
                            " is out of bounds for dimension ", dim,
                            " with size ", index_upper_bound);

                kernel_func(
                  (opmath_t*)self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride,
                  (scalar_t*)src_data + (is_scatter_like ? i : idx_dim) * src_dim_stride);

                self_data += strides[SELF_ITER_STRIDE_IDX];
                index_data += strides[INDEX_ITER_STRIDE_IDX];
                src_data += strides[SRC_ITER_STRIDE_IDX];
              }
            }
          }
        };
        iter.for_each(loop, grain_size);
      }
    );
    if (need_acc) {
      self.copy_(buffer);
    }
  }

  void operator()(const Tensor& self, int64_t dim,
    const Tensor& index, const Tensor& src,
    const std::string& method_name, ReduceMaximum& kernel_func) {
    Tensor buffer;
    bool need_acc = isReducedFloatingType(self.scalar_type());
    create_acc_buffer(buffer, self, need_acc);

    auto iter = TensorIteratorConfig()
      .check_all_same_dtype(false)
      .resize_outputs(false)
      // NOLINTNEXTLINE(bugprone-argument-comment)
      .declare_static_shape(index.sizes(), /*squash_dim=*/dim)
      .add_output(buffer)
      .add_const_input(src)
      .add_const_input(index)
      .build();

    auto self_dim_stride = ensure_nonempty_stride(buffer, dim);
    auto self_dim_size = ensure_nonempty_size(buffer, dim);

    auto index_dim_stride = ensure_nonempty_stride(index, dim);
    auto index_dim_size = ensure_nonempty_size(index, dim);

    auto src_dim_stride = ensure_nonempty_stride(src, dim);
    auto src_dim_size = ensure_nonempty_size(src, dim);

    auto index_upper_bound = is_scatter_like ? self_dim_size : src_dim_size;

    int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / index_dim_size);

    AT_DISPATCH_ALL_TYPES_AND3(
      ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, iter.dtype(1),
      "scatter_gather_tensor_cpu_reduce_amax", [&] {
        constexpr auto SELF_ITER_STRIDE_IDX = 0;
        constexpr auto INDEX_ITER_STRIDE_IDX = 2;
        constexpr auto SRC_ITER_STRIDE_IDX = 1;
        using opmath_t = at::opmath_type<scalar_t>;
        _cpu_scatter_gather_dim_loop<is_scatter_like> loop_func;
        auto loop = [&](char** data, const int64_t* strides, int64_t n) {
          auto* self_data_bytes = data[SELF_ITER_STRIDE_IDX];
          auto* index_data_bytes = data[INDEX_ITER_STRIDE_IDX];
          auto* src_data_bytes = data[SRC_ITER_STRIDE_IDX];
          // depending on whether dim is the last dimension, we either run the
          // dim loop inside the TensorIterator loop or the TensorIterator loop
          // inside the dim loop
          if (dim == buffer.dim() - 1) {
            for (const auto nelem C10_UNUSED : c10::irange(n)) {
              // dim loop is a separate code block
              // for better performance
              loop_func.template operator()<scalar_t, ReduceMaximum>(
                 (opmath_t*)self_data_bytes, self_dim_stride,
                 (int64_t*)index_data_bytes, index_dim_stride,
                 (scalar_t*)src_data_bytes, src_dim_stride,
                 dim, index_dim_size, index_upper_bound,
                 kernel_func
               );

              self_data_bytes += strides[SELF_ITER_STRIDE_IDX];
              index_data_bytes += strides[INDEX_ITER_STRIDE_IDX];
              src_data_bytes += strides[SRC_ITER_STRIDE_IDX];
            }
          }
          else {
            for (const auto i : c10::irange(index_dim_size)) {
              auto* self_data = self_data_bytes;
              auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride);
              auto* src_data = src_data_bytes;
              for (const auto nelem C10_UNUSED : c10::irange(n)) {
                int64_t idx_dim = *(int64_t*)index_data;
                // we are not putting idx_dim in the error message because it disables
                // loop optimization in clang-7
                TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound,
                            "index ", *(int64_t*)index_data,
                            " is out of bounds for dimension ", dim,
                            " with size ", index_upper_bound);

                kernel_func(
                  (opmath_t*)self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride,
                  (scalar_t*)src_data + (is_scatter_like ? i : idx_dim) * src_dim_stride);

                self_data += strides[SELF_ITER_STRIDE_IDX];
                index_data += strides[INDEX_ITER_STRIDE_IDX];
                src_data += strides[SRC_ITER_STRIDE_IDX];
              }
            }
          }
        };
        iter.for_each(loop, grain_size);
      }
    );
    if (need_acc) {
      self.copy_(buffer);
    }
  }

  void operator()(const Tensor& self, int64_t dim,
    const Tensor& index, const Tensor& src,
    const std::string& method_name, ReduceMinimum& kernel_func) {

    Tensor buffer;
    bool need_acc = isReducedFloatingType(self.scalar_type());
    create_acc_buffer(buffer, self, need_acc);

    auto iter = TensorIteratorConfig()
      .check_all_same_dtype(false)
      .resize_outputs(false)
      // NOLINTNEXTLINE(bugprone-argument-comment)
      .declare_static_shape(index.sizes(), /*squash_dim=*/dim)
      .add_output(buffer)
      .add_const_input(src)
      .add_const_input(index)
      .build();

    auto self_dim_stride = ensure_nonempty_stride(buffer, dim);
    auto self_dim_size = ensure_nonempty_size(buffer, dim);

    auto index_dim_stride = ensure_nonempty_stride(index, dim);
    auto index_dim_size = ensure_nonempty_size(index, dim);

    auto src_dim_stride = ensure_nonempty_stride(src, dim);
    auto src_dim_size = ensure_nonempty_size(src, dim);

    auto index_upper_bound = is_scatter_like ? self_dim_size : src_dim_size;

    int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / index_dim_size);

    AT_DISPATCH_ALL_TYPES_AND3(
      ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, iter.dtype(1),
      "scatter_gather_tensor_cpu_reduce_amin", [&] {
        constexpr auto SELF_ITER_STRIDE_IDX = 0;
        constexpr auto INDEX_ITER_STRIDE_IDX = 2;
        constexpr auto SRC_ITER_STRIDE_IDX = 1;
        using opmath_t = at::opmath_type<scalar_t>;
        _cpu_scatter_gather_dim_loop<is_scatter_like> loop_func;
        auto loop = [&](char** data, const int64_t* strides, int64_t n) {
          auto* self_data_bytes = data[SELF_ITER_STRIDE_IDX];
          auto* index_data_bytes = data[INDEX_ITER_STRIDE_IDX];
          auto* src_data_bytes = data[SRC_ITER_STRIDE_IDX];
          // depending on whether dim is the last dimension, we either run the
          // dim loop inside the TensorIterator loop or the TensorIterator loop
          // inside the dim loop
          if (dim == buffer.dim() - 1) {
            for (const auto nelem C10_UNUSED : c10::irange(n)) {
              // dim loop is a separate code block
              // for better performance
              loop_func.template operator()<scalar_t, ReduceMinimum>(
                 (opmath_t*)self_data_bytes, self_dim_stride,
                 (int64_t*)index_data_bytes, index_dim_stride,
                 (scalar_t*)src_data_bytes, src_dim_stride,
                 dim, index_dim_size, index_upper_bound,
                 kernel_func
               );

              self_data_bytes += strides[SELF_ITER_STRIDE_IDX];
              index_data_bytes += strides[INDEX_ITER_STRIDE_IDX];
              src_data_bytes += strides[SRC_ITER_STRIDE_IDX];
            }
          }
          else {
            for (const auto i : c10::irange(index_dim_size)) {
              auto* self_data = self_data_bytes;
              auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride);
              auto* src_data = src_data_bytes;
              for (const auto nelem C10_UNUSED : c10::irange(n)) {
                int64_t idx_dim = *(int64_t*)index_data;
                // we are not putting idx_dim in the error message because it disables
                // loop optimization in clang-7
                TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound,
                            "index ", *(int64_t*)index_data,
                            " is out of bounds for dimension ", dim,
                            " with size ", index_upper_bound);

                kernel_func(
                  (opmath_t*)self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride,
                  (scalar_t*)src_data + (is_scatter_like ? i : idx_dim) * src_dim_stride);

                self_data += strides[SELF_ITER_STRIDE_IDX];
                index_data += strides[INDEX_ITER_STRIDE_IDX];
                src_data += strides[SRC_ITER_STRIDE_IDX];
              }
            }
          }
        };
        iter.for_each(loop, grain_size);
      }
    );
    if (need_acc) {
      self.copy_(buffer);
    }
  }
};

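// Without FBGEMM there is no radix_sort_parallel implementation; provide a
// stub with the same signature that fails loudly if the expanded-index fast
// path is ever reached in such a build.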
#ifndef USE_FBGEMM
namespace fbgemm {

template <typename K, typename V>
std::pair<K*, V*> radix_sort_parallel(
    K* const inp_key_buf,
    V* const inp_value_buf,
    K* const tmp_key_buf,
    V* const tmp_value_buf,
    const int64_t elements_count,
    const int64_t max_value) {
  TORCH_INTERNAL_ASSERT(false, "radix_sort_parallel: ATen not compiled with FBGEMM support");
  return std::make_pair(nullptr, nullptr);
}

}
#endif

// Note [scatter reduce optimization]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//
// 1. motivation: optimize `scatter_reduce` for the classic PyG use case:
//   `scatter_reduce` is used extensively in 'message passing' when
//   aggregating info.
//
//   Typically, `self` will be a 2D tensor and `index` a 1D extended/broadcasted
//   tensor, which means the aggregation is done rowwise and we can vectorize
//   over the inner dimension.
//
// 2. implementation: map `scatter_reduce` to an `spmm` reduce
//   in the shape of `[M, nnz]` * `[nnz, K]`, where:
//
//   M: self_dim_size
//   nnz: index_dim_size
//   K: index.numel() / index_dim_size;
//
//   step 1: convert input index to CSR format (use radix_sort to
//     solve write addr conflicts on `self` tensor)
//
//   step 2: spmm reduce, parallel on M and vectorize on K
//

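// Illustrative example (hypothetical values): with index = [1, 0, 2, 0] along
// dim 0, radix_sort on (key=index, value=position) yields sorted keys
// [0, 0, 1, 2] and values [1, 3, 0, 2]; the unique keys become the CSR row ids
// row_index = [0, 1, 2] and the boundaries between them the row offsets
// row_index_offset = [0, 2, 3, 4], so each output row can then be reduced
// independently without write conflicts.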
template <typename scalar_t, ReductionType reduce>
void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index, const Tensor& src, bool include_self) {
  const int64_t* index_data = index.const_data_ptr<int64_t>();
  scalar_t* self_data = self.data_ptr<scalar_t>();
  const scalar_t* src_data = src.const_data_ptr<scalar_t>();

  const int64_t M = ensure_nonempty_size(self, 0);
  const int64_t nnz = ensure_nonempty_size(index, 0);
  const int64_t K = index.numel() / nnz;

  const int64_t index_upper_bound = M;

  auto keys = std::make_unique<int64_t[]>(nnz);
  auto values = std::make_unique<int64_t[]>(nnz);
  auto keys_tmp = std::make_unique<int64_t[]>(nnz);
  auto values_tmp = std::make_unique<int64_t[]>(nnz);
  at::parallel_for(0, nnz, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      int64_t index = index_data[i];
      TORCH_CHECK(index >= 0 && index < index_upper_bound,
                  "index ", index,
                  " is out of bounds for dimension ", 0,
                  " with size ", index_upper_bound);
      keys[i] = index;
      values[i] = i;
    }
  });

  int64_t* sorted_col_index_keys = nullptr;
  int64_t* sorted_col_index_values = nullptr;
  std::tie(sorted_col_index_keys, sorted_col_index_values) = fbgemm::radix_sort_parallel(
      keys.get(),
      values.get(),
      keys_tmp.get(),
      values_tmp.get(),
      nnz,
      M);

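  // Count unique keys per thread and prefix-sum the counts: the unique sorted
  // keys become the non-empty output rows and the positions where the key
  // changes become CSR-style row offsets, which are filled in below.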
  int num_threads = at::get_num_threads();
  std::vector<int64_t> num_uniq(num_threads, 0);
  at::parallel_for(1, nnz, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    for (const auto i : c10::irange(begin, end)) {
      if (sorted_col_index_keys[i] != sorted_col_index_keys[i - 1]) {
        num_uniq[tid]++;
      }
    }
  });
  num_uniq[0]++;
  for (const auto n : c10::irange(1, num_threads)) {
    num_uniq[n] += num_uniq[n - 1];
  }

  // in case some rows are not written into, num_nonzero_rows will be smaller than M
  int64_t num_nonzero_rows = num_uniq[num_threads - 1];
  auto row_index_tmp = std::make_unique<int64_t[]>(num_nonzero_rows);
  auto row_index_offset_tmp = std::make_unique<int64_t[]>(num_nonzero_rows + 1);
  int64_t* row_index = row_index_tmp.get();
  int64_t* row_index_offset = row_index_offset_tmp.get();
  row_index[0] = sorted_col_index_keys[0];
  row_index_offset[0] = 0;
  row_index_offset[num_nonzero_rows] = nnz;

  at::parallel_for(1, nnz, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    int64_t* t_index = row_index + ((tid == 0) ? 1 : num_uniq[tid - 1]);
    int64_t* t_index_offset = row_index_offset + ((tid == 0) ? 1 : num_uniq[tid - 1]);
    for (const auto i : c10::irange(begin, end)) {
      if (sorted_col_index_keys[i] != sorted_col_index_keys[i - 1]) {
        *t_index = sorted_col_index_keys[i];
        *t_index_offset = i;
        t_index++;
        t_index_offset++;
      }
    }
  });

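  // For reduced floating point types (BFloat16/Half), accumulate each row in a
  // per-thread fp32 (opmath) scratch buffer of width K and convert the result
  // back to the output dtype once the row is done.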
  using opmath_t = at::opmath_type<scalar_t>;
  Tensor buffer;
  opmath_t* buffer_data = nullptr;
  static constexpr bool need_acc = is_reduced_floating_point_v<scalar_t>;
  if constexpr (need_acc) {
    auto acc_type = at::toAccumulateType(self.scalar_type(), /*is_cuda=*/true);
    buffer = at::zeros({num_threads, K}, self.options().dtype(acc_type));
    buffer_data = buffer.data_ptr<opmath_t>();
  }

  // TODO: do blocking on col dimension to reduce WR bandwidth
  at::parallel_for(0, num_nonzero_rows, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    TORCH_CHECK(tid < num_threads,
                "expect thread id smaller than ", num_threads, ", got thread id ", tid);
    opmath_t* buffer_ptr = nullptr;

    for (const auto m : c10::irange(begin, end)) {
      int64_t row = row_index[m];
      int64_t off_start = row_index_offset[m];
      int64_t off_end = row_index_offset[m + 1];
      scalar_t* self_ptr = self_data + row * K;
      if constexpr (need_acc) {
        buffer_ptr = buffer_data + tid * K;
      } else {
        buffer_ptr = reinterpret_cast<opmath_t*>(self_ptr);
      }

      // step 1: reinit rows in `self` if needed
      _init<scalar_t, reduce>(self_ptr, buffer_ptr, K, include_self);

      // step 2: reduce
      for (const auto n : c10::irange(off_start, off_end)) {
        int64_t col = sorted_col_index_values[n];
        update<scalar_t, reduce>(buffer_ptr, src_data + col * K, K);
      }
      if constexpr (need_acc) {
        vec::convert(buffer_ptr, self_ptr, K);
      }

      // step 3: finalize
      int64_t count = include_self ? 1 : 0;
      count += off_end - off_start;
      write<scalar_t, reduce>(self_ptr, count, K);
    }
  });
}

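// Row-wise gather fast path: result[m, :] = self[index[m], :], copying each
// contiguous row of length K with Vectorized loads/stores plus a scalar tail.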
template <typename scalar_t>
void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& index, const Tensor& self) {
  const int64_t* index_data = index.const_data_ptr<int64_t>();
  scalar_t* result_data = result.data_ptr<scalar_t>();
  const scalar_t* self_data = self.const_data_ptr<scalar_t>();

  const int64_t M = ensure_nonempty_size(result, 0);
  const int64_t N = ensure_nonempty_size(self, 0);
  const int64_t K = index.numel() / M;

  const int64_t index_upper_bound = N;

  using Vec = vec::Vectorized<scalar_t>;
  int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / K);
  at::parallel_for(0, M, grain_size, [&](int64_t begin, int64_t end) {
    for (const auto m : c10::irange(begin, end)) {
      scalar_t* result_ptr = result_data + m * K;
      int64_t index = index_data[m];
      TORCH_CHECK(index >= 0 && index < index_upper_bound,
                  "index ", index,
                  " is out of bounds for dimension ", 0,
                  " with size ", index_upper_bound);
      const scalar_t* self_ptr = self_data + index * K;
      int64_t d = 0;
      for (; d < K - (K % Vec::size()); d += Vec::size()) {
        Vec out_vec = Vec::loadu(self_ptr + d);
        out_vec.store(result_ptr + d);
      }
      #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
      # pragma unroll
      #endif
      for (; d < K; d++) {
        result_ptr[d] = self_ptr[d];
      }
    }
  });
}

void scatter_add_expanded_index_kernel(const Tensor& self, const Tensor& index, const Tensor& src) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
    ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "scatter_add_expanded_index", [&] {
      cpu_scatter_reduce_expanded_index<scalar_t, ReductionType::SUM>(self, index, src, /*include_self*/true);
  });
}

void scatter_reduce_expanded_index_kernel(
    const Tensor& self, const Tensor& index, const Tensor& src,
    const ReductionType& reduction, bool include_self) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
    ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "scatter_reduce_expanded_index", [&] {
    AT_DISPATCH_REDUCTION_TYPES(reduction, [&]() {
      cpu_scatter_reduce_expanded_index<scalar_t, reduce>(self, index, src, include_self);
    });
  });
}

void gather_expanded_index_kernel(const Tensor& result, const Tensor& self, const Tensor& index) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
    ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "gather_expanded_index", [&] {
      cpu_gather_expanded_index_kernel<scalar_t>(result, index, self);
  });
}

void gather_cpu_kernel(const Tensor& result, const Tensor& self, int64_t dim, const Tensor& index) {
  cpu_scatter_gather_base_kernel</*is_scatter_like=*/false>()(
    result, dim, index, self,
    "gather_out_cpu", tensor_assign);
}

void scatter_cpu_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
  cpu_scatter_gather_base_kernel<>()(
    self, dim, index, src, "scatter_cpu_", tensor_assign);
}

void scatter_fill_cpu_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& value) {
  cpu_scatter_gather_base_kernel<>()(
    self, dim, index, value, "scatter_fill_cpu_", tensor_assign);
}

void scatter_add_cpu_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
  cpu_scatter_gather_base_kernel<>()(
    self, dim, index, src,
    "scatter_add_", reduce_add);
}

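// Note: scatter_reduce_cpu_kernel (registered to scatter_reduce_stub) only
// handles the sum and prod reductions, while scatter_reduce_two_cpu_kernel
// (scatter_reduce_two_stub) additionally handles mean/amax/amin.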
void scatter_reduce_cpu_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
                               const Tensor& src, const ReductionType& reduce) {
  switch (reduce) {
  case ReductionType::SUM :
    cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                       "scatter_reduce_add_", reduce_add);
    break;
  case ReductionType::PROD :
    cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                       "scatter_reduce_multiply_", reduce_multiply);
    break;
  default :
    break;
  }
}

void scatter_reduce_two_cpu_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
                                   const Tensor& src, const ReductionType& reduce) {
  switch (reduce) {
  case ReductionType::SUM :
    cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                       "scatter_reduce_sum_", reduce_add);
    break;
  case ReductionType::PROD :
    cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                       "scatter_reduce_prod_", reduce_multiply);
    break;
  case ReductionType::MAX :
    cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                       "scatter_reduce_amax_", reduce_maximum);
    break;
  case ReductionType::MIN :
    cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                       "scatter_reduce_amin_", reduce_minimum);
    break;
  case ReductionType::MEAN :
    cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                       "scatter_reduce_mean_", reduce_mean);
    break;
  }
}

void scatter_scalar_reduce_cpu_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
                                      const Scalar& value, const ReductionType& reduce) {
  switch (reduce) {
  case ReductionType::SUM :
    cpu_scatter_gather_base_kernel<>()(self, dim, index, value,
                                       "scatter_scalar_reduce_add_", reduce_add);
    break;
  case ReductionType::PROD :
    cpu_scatter_gather_base_kernel<>()(self, dim, index, value,
                                       "scatter_scalar_reduce_multiply_", reduce_multiply);
    break;
  default:
    break;
  }
}

} // anonymous namespace

REGISTER_DISPATCH(gather_stub, &gather_cpu_kernel);
REGISTER_DISPATCH(scatter_stub, &scatter_cpu_kernel);
REGISTER_DISPATCH(scatter_fill_stub, &scatter_fill_cpu_kernel);
REGISTER_DISPATCH(scatter_add_stub, &scatter_add_cpu_kernel);
REGISTER_DISPATCH(scatter_reduce_stub, &scatter_reduce_cpu_kernel);
REGISTER_DISPATCH(scatter_scalar_reduce_stub, &scatter_scalar_reduce_cpu_kernel);
REGISTER_DISPATCH(scatter_reduce_two_stub, &scatter_reduce_two_cpu_kernel);

// fast paths for GNN usage
REGISTER_DISPATCH(scatter_add_expanded_index_stub, &scatter_add_expanded_index_kernel);
REGISTER_DISPATCH(scatter_reduce_expanded_index_stub, &scatter_reduce_expanded_index_kernel);
REGISTER_DISPATCH(gather_expanded_index_stub, &gather_expanded_index_kernel);

} // namespace at::native