#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/NonEmptyUtils.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/Parallel.h>
#include <ATen/native/cpu/ReduceUtils.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>
#ifdef USE_FBGEMM
#include <fbgemm/Utils.h>
#endif
#include <ATen/OpMathType.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros.h>
#endif
namespace at::native {

namespace {

// Implement as functors since lambdas don't get optimized.
class ReduceMultiply {
public:
  template <typename scalar_t>
  constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
    using opmath_t = at::opmath_type<scalar_t>;
    *self_data *= opmath_t(*src_data);
  }

  constexpr void operator() (bool * self_data, bool * src_data) const {
    *self_data = *self_data && *src_data;
  }
};
static ReduceMultiply reduce_multiply;

class ReduceAdd {
public:
  template <typename scalar_t>
  constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
    using opmath_t = at::opmath_type<scalar_t>;
    *self_data += opmath_t(*src_data);
  }
};
static ReduceAdd reduce_add;

class ReduceMean {
public:
  template <typename scalar_t>
  constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
    using opmath_t = at::opmath_type<scalar_t>;
    *self_data += opmath_t(*src_data);
  }
};
static ReduceMean reduce_mean;

class ReduceMaximum {
public:
  template <typename scalar_t>
  constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
    using opmath_t = at::opmath_type<scalar_t>;
    *self_data = at::_isnan<scalar_t>(*src_data) ? opmath_t(*src_data) : std::max(*self_data, opmath_t(*src_data));
  }
};
static ReduceMaximum reduce_maximum;

class ReduceMinimum {
public:
  template <typename scalar_t>
  constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
    using opmath_t = at::opmath_type<scalar_t>;
    *self_data = at::_isnan<scalar_t>(*src_data) ? opmath_t(*src_data) : std::min(*self_data, opmath_t(*src_data));
  }
};
static ReduceMinimum reduce_minimum;

class TensorAssign {
public:
  template <typename scalar_t>
  constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
    using opmath_t = at::opmath_type<scalar_t>;
    *self_data = opmath_t(*src_data);
  }
};
static TensorAssign tensor_assign;
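
// A minimal usage sketch (illustrative only): each functor folds one source
// element into an accumulator of opmath type, e.g.
//   float acc = 0.f;
//   at::BFloat16 v(1.5f);
//   reduce_add(&acc, &v);  // acc += float(v)
// Note that ReduceMean only accumulates the sum here; dividing by the element
// count is expected to happen in the caller's finalization step.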

template <bool is_scatter_like = true>
struct _cpu_scatter_gather_dim_loop {
  template <typename scalar_t, typename func_t>
  void operator()(
    at::opmath_type<scalar_t>* self_data, int64_t self_dim_stride,
    int64_t* index_data, int64_t index_dim_stride,
    scalar_t* src_data, int64_t src_dim_stride,
    int64_t dim, int64_t index_dim_size,
    int64_t index_upper_bound,
    func_t& f
  ) {

    for (const auto i : c10::irange(index_dim_size)) {
      int64_t idx_dim = index_data[i * index_dim_stride];
      // we are not putting idx_dim in the error message because it disables
      // loop optimization in clang-7
      TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound,
        "index ", index_data[i * index_dim_stride],
        " is out of bounds for dimension ", dim,
        " with size ", index_upper_bound
      );

      f(
        self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride,
        src_data + (is_scatter_like ? i : idx_dim) * src_dim_stride
      );
    }
  }

  template <typename scalar_t, typename func_t>
  void operator()(
    at::opmath_type<scalar_t>* self_data, int64_t self_dim_stride,
    int64_t* index_data, int64_t index_dim_stride,
    Scalar value,
    int64_t dim, int64_t index_dim_size,
    int64_t index_upper_bound,
    func_t& f
  ) {

    for (const auto i : c10::irange(index_dim_size)) {
      int64_t idx_dim = index_data[i * index_dim_stride];
      // we are not putting idx_dim in the error message because it disables
      // loop optimization in clang-7
      TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound,
        "index ", index_data[i * index_dim_stride],
        " is out of bounds for dimension ", dim,
        " with size ", index_upper_bound
      );
      auto temp = value.to<scalar_t>();
      f(
        self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride, &temp
      );
    }
  }
};
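
// Addressing summary for the loops above (illustrative): in scatter-like mode
// the index picks the destination, self[idx_dim] <- f(src[i]); in gather-like
// mode it picks the source, self[i] <- f(src[idx_dim]).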

inline void create_acc_buffer(Tensor& buffer, const Tensor& self, bool need_acc) {
  if (need_acc) {
    auto acc_type = at::toOpMathType(self.scalar_type());
    buffer = at::empty(self.sizes(), self.options().dtype(acc_type));
    buffer.copy_(self);
  } else {
    buffer = self;
  }
}
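
// For reduced floating point types (BFloat16/Half) the kernels below accumulate
// into a temporary tensor of the corresponding opmath type (float) and copy the
// result back at the end, which avoids rounding on every partial update.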

template <bool is_scatter_like = true>
struct cpu_scatter_gather_base_kernel {
  template <typename func_t>
  void operator()(const Tensor& self, int64_t dim,
    const Tensor& index, const Scalar& value,
    const std::string& method_name, func_t& kernel_func) {

    Tensor buffer;
    bool need_acc = isReducedFloatingType(self.scalar_type());
    create_acc_buffer(buffer, self, need_acc);

    auto index_sizes = ensure_nonempty_vec(index.sizes().vec());
    auto index_strides = ensure_nonempty_vec(index.strides().vec());

    // `dim` is traversed inside the kernel,
    // which is why index.stride(dim) = 0 and index.size(dim) = 1.
    // Also, index.size(dim) = 1 makes sure that TensorIterator.DimCounter
    // has the following form : (i_1,..., i_{dim-1}, 0, i_{dim+1},...,i_n).
    index_sizes[dim] = 1;
    index_strides[dim] = 0;

    auto iter = TensorIteratorConfig()
      .check_all_same_dtype(false)
      .resize_outputs(false)
      // NOLINTNEXTLINE(bugprone-argument-comment)
      .declare_static_shape(index.sizes(), /*squash_dim=*/dim)
      .add_output(buffer)
      .add_const_input(index)
      .build();

    auto self_dim_stride = ensure_nonempty_stride(buffer, dim);
    auto self_dim_size = ensure_nonempty_size(buffer, dim);

    auto index_dim_stride = ensure_nonempty_stride(index, dim);
    auto index_dim_size = ensure_nonempty_size(index, dim);

    auto index_upper_bound = self_dim_size;

    // since the index dimension is squashed, the grain size needs to be adjusted
    // accordingly to keep the parallel granularity the same.
    int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / index_dim_size);
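
    // e.g. with the default at::internal::GRAIN_SIZE of 32768 and
    // index_dim_size == 64, each chunk handed to for_each spans roughly 512
    // TensorIterator elements, each of which runs a 64-step dim loop.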

    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, self.scalar_type(),
      "scatter_gather_scalar_cpu", [&] {
        constexpr auto SELF_ITER_STRIDE_IDX = 0;
        constexpr auto INDEX_ITER_STRIDE_IDX = 1;
        using opmath_t = at::opmath_type<scalar_t>;
        _cpu_scatter_gather_dim_loop<is_scatter_like> loop_func;
        auto loop = [&](char** data, const int64_t* strides, int64_t n) {
          auto* self_data_bytes = data[SELF_ITER_STRIDE_IDX];
          auto* index_data_bytes = data[INDEX_ITER_STRIDE_IDX];
          // we change between the TensorIterator-dim and dim-TensorIterator
          // loop orders depending on whether dim is the last dimension
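          // when dim is the last (typically contiguous) dimension, keeping the
          // dim loop innermost walks memory sequentially; otherwise the
          // TensorIterator loop is kept innermost instead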
          if (dim == buffer.dim() - 1) {
            for (const auto nelem C10_UNUSED : c10::irange(n)) {
              // dim loop is a separate code block
              // for better performance
              loop_func.template operator()<scalar_t, func_t>(
                (opmath_t*)self_data_bytes, self_dim_stride,
                (int64_t*)index_data_bytes, index_dim_stride,
                value, dim, index_dim_size, index_upper_bound,
                kernel_func);

              self_data_bytes += strides[SELF_ITER_STRIDE_IDX];
              index_data_bytes += strides[INDEX_ITER_STRIDE_IDX];
            }
          }
          else {
            for (const auto i : c10::irange(index_dim_size)) {
              auto* self_data = self_data_bytes;
              auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride);
              for (const auto nelem C10_UNUSED : c10::irange(n)) {
                int64_t idx_dim = *(int64_t*)index_data;
                // we are not putting idx_dim in the error message because it disables
                // loop optimization in clang-7
                TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound,
                  "index ", *(int64_t*)index_data,
                  " is out of bounds for dimension ", dim,
                  " with size ", index_upper_bound);

                auto temp = value.to<scalar_t>();
                kernel_func((opmath_t*)self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride, &temp);

                self_data += strides[SELF_ITER_STRIDE_IDX];
                index_data += strides[INDEX_ITER_STRIDE_IDX];
              }
            }
          }
        };
        iter.for_each(loop, grain_size);
      }
    );
    if (need_acc) {
      self.copy_(buffer);
    }
  }

  template <typename func_t>
  void operator()(const Tensor& self, int64_t dim,
    const Tensor& index, const Tensor& src,
    const std::string& method_name, func_t& kernel_func) {

    Tensor buffer;
    bool need_acc = isReducedFloatingType(self.scalar_type());
    create_acc_buffer(buffer, self, need_acc);

    auto iter = TensorIteratorConfig()
      .check_all_same_dtype(false)
      .resize_outputs(false)
      // NOLINTNEXTLINE(bugprone-argument-comment)
      .declare_static_shape(index.sizes(), /*squash_dim=*/dim)
      .add_output(buffer)
      .add_const_input(src)
      .add_const_input(index)
      .build();

    auto self_dim_stride = ensure_nonempty_stride(buffer, dim);
    auto self_dim_size = ensure_nonempty_size(buffer, dim);

    auto index_dim_stride = ensure_nonempty_stride(index, dim);
    auto index_dim_size = ensure_nonempty_size(index, dim);

    auto src_dim_stride = ensure_nonempty_stride(src, dim);
    auto src_dim_size = ensure_nonempty_size(src, dim);

    auto index_upper_bound = is_scatter_like ? self_dim_size : src_dim_size;

    int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / index_dim_size);

    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, iter.dtype(1),
      "scatter_gather_tensor_cpu", [&] {
        constexpr auto SELF_ITER_STRIDE_IDX = 0;
        constexpr auto INDEX_ITER_STRIDE_IDX = 2;
        constexpr auto SRC_ITER_STRIDE_IDX = 1;
        using opmath_t = at::opmath_type<scalar_t>;
        _cpu_scatter_gather_dim_loop<is_scatter_like> loop_func;
        auto loop = [&](char** data, const int64_t* strides, int64_t n) {
          auto* self_data_bytes = data[SELF_ITER_STRIDE_IDX];
          auto* index_data_bytes = data[INDEX_ITER_STRIDE_IDX];
          auto* src_data_bytes = data[SRC_ITER_STRIDE_IDX];
          // we change between the TensorIterator-dim and dim-TensorIterator
          // loop orders depending on whether dim is the last dimension
          if (dim == buffer.dim() - 1) {
            for (const auto nelem C10_UNUSED : c10::irange(n)) {
              // dim loop is a separate code block
              // for better performance
              loop_func.template operator()<scalar_t, func_t>(
                (opmath_t*)self_data_bytes, self_dim_stride,
                (int64_t*)index_data_bytes, index_dim_stride,
                (scalar_t*)src_data_bytes, src_dim_stride,
                dim, index_dim_size, index_upper_bound,
                kernel_func
              );

              self_data_bytes += strides[SELF_ITER_STRIDE_IDX];
              index_data_bytes += strides[INDEX_ITER_STRIDE_IDX];
              src_data_bytes += strides[SRC_ITER_STRIDE_IDX];
            }
          }
          else {
            for (const auto i : c10::irange(index_dim_size)) {
              auto* self_data = self_data_bytes;
              auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride);
              auto* src_data = src_data_bytes;
              for (const auto nelem C10_UNUSED : c10::irange(n)) {
                int64_t idx_dim = *(int64_t*)index_data;
                // we are not putting idx_dim in the error message because it disables
                // loop optimization in clang-7
                TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound,
                  "index ", *(int64_t*)index_data,
                  " is out of bounds for dimension ", dim,
                  " with size ", index_upper_bound);

                kernel_func(
                  (opmath_t*)self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride,
                  (scalar_t*)src_data + (is_scatter_like ? i : idx_dim) * src_dim_stride);

                self_data += strides[SELF_ITER_STRIDE_IDX];
                index_data += strides[INDEX_ITER_STRIDE_IDX];
                src_data += strides[SRC_ITER_STRIDE_IDX];
              }
            }
          }
        };
        iter.for_each(loop, grain_size);
      }
    );
    if (need_acc) {
      self.copy_(buffer);
    }
  }

  void operator()(const Tensor& self, int64_t dim,
    const Tensor& index, const Tensor& src,
    const std::string& method_name, ReduceMean& kernel_func) {

    Tensor buffer;
    bool need_acc = isReducedFloatingType(self.scalar_type());
    create_acc_buffer(buffer, self, need_acc);

    auto iter = TensorIteratorConfig()
      .check_all_same_dtype(false)
      .resize_outputs(false)
      // NOLINTNEXTLINE(bugprone-argument-comment)
      .declare_static_shape(index.sizes(), /*squash_dim=*/dim)
      .add_output(buffer)
      .add_const_input(src)
      .add_const_input(index)
      .build();

    auto self_dim_stride = ensure_nonempty_stride(buffer, dim);
    auto self_dim_size = ensure_nonempty_size(buffer, dim);

    auto index_dim_stride = ensure_nonempty_stride(index, dim);
    auto index_dim_size = ensure_nonempty_size(index, dim);

    auto src_dim_stride = ensure_nonempty_stride(src, dim);
    auto src_dim_size = ensure_nonempty_size(src, dim);

    auto index_upper_bound = is_scatter_like ? self_dim_size : src_dim_size;

    int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / index_dim_size);

    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
      ScalarType::Half, ScalarType::BFloat16, iter.dtype(1),
      "scatter_gather_tensor_cpu_reduce_mean", [&] {
        constexpr auto SELF_ITER_STRIDE_IDX = 0;
        constexpr auto INDEX_ITER_STRIDE_IDX = 2;
        constexpr auto SRC_ITER_STRIDE_IDX = 1;
        using opmath_t = at::opmath_type<scalar_t>;
        _cpu_scatter_gather_dim_loop<is_scatter_like> loop_func;
        auto loop = [&](char** data, const int64_t* strides, int64_t n) {
          auto* self_data_bytes = data[SELF_ITER_STRIDE_IDX];
          auto* index_data_bytes = data[INDEX_ITER_STRIDE_IDX];
          auto* src_data_bytes = data[SRC_ITER_STRIDE_IDX];
          // we change between the TensorIterator-dim and dim-TensorIterator
          // loop orders depending on whether dim is the last dimension
          if (dim == buffer.dim() - 1) {
            for (const auto nelem C10_UNUSED : c10::irange(n)) {
              // dim loop is a separate code block
              // for better performance
              loop_func.template operator()<scalar_t, ReduceMean>(
                (opmath_t*)self_data_bytes, self_dim_stride,
                (int64_t*)index_data_bytes, index_dim_stride,
                (scalar_t*)src_data_bytes, src_dim_stride,
                dim, index_dim_size, index_upper_bound,
                kernel_func
              );

              self_data_bytes += strides[SELF_ITER_STRIDE_IDX];
              index_data_bytes += strides[INDEX_ITER_STRIDE_IDX];
              src_data_bytes += strides[SRC_ITER_STRIDE_IDX];
            }
          }
          else {
            for (const auto i : c10::irange(index_dim_size)) {
              auto* self_data = self_data_bytes;
              auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride);
              auto* src_data = src_data_bytes;
              for (const auto nelem C10_UNUSED : c10::irange(n)) {
                int64_t idx_dim = *(int64_t*)index_data;
                // we are not putting idx_dim in the error message because it disables
                // loop optimization in clang-7
                TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound,
                  "index ", *(int64_t*)index_data,
                  " is out of bounds for dimension ", dim,
                  " with size ", index_upper_bound);

                kernel_func(
                  (opmath_t*)self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride,
                  (scalar_t*)src_data + (is_scatter_like ? i : idx_dim) * src_dim_stride);

                self_data += strides[SELF_ITER_STRIDE_IDX];
                index_data += strides[INDEX_ITER_STRIDE_IDX];
                src_data += strides[SRC_ITER_STRIDE_IDX];
              }
            }
          }
        };
        iter.for_each(loop, grain_size);
      }
    );
    if (need_acc) {
      self.copy_(buffer);
    }
  }

  void operator()(const Tensor& self, int64_t dim,
    const Tensor& index, const Tensor& src,
    const std::string& method_name, ReduceMaximum& kernel_func) {
    Tensor buffer;
    bool need_acc = isReducedFloatingType(self.scalar_type());
    create_acc_buffer(buffer, self, need_acc);

    auto iter = TensorIteratorConfig()
      .check_all_same_dtype(false)
      .resize_outputs(false)
      // NOLINTNEXTLINE(bugprone-argument-comment)
      .declare_static_shape(index.sizes(), /*squash_dim=*/dim)
      .add_output(buffer)
      .add_const_input(src)
      .add_const_input(index)
      .build();

    auto self_dim_stride = ensure_nonempty_stride(buffer, dim);
    auto self_dim_size = ensure_nonempty_size(buffer, dim);

    auto index_dim_stride = ensure_nonempty_stride(index, dim);
    auto index_dim_size = ensure_nonempty_size(index, dim);

    auto src_dim_stride = ensure_nonempty_stride(src, dim);
    auto src_dim_size = ensure_nonempty_size(src, dim);

    auto index_upper_bound = is_scatter_like ? self_dim_size : src_dim_size;

    int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / index_dim_size);

    AT_DISPATCH_ALL_TYPES_AND3(
      ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, iter.dtype(1),
      "scatter_gather_tensor_cpu_reduce_amax", [&] {
        constexpr auto SELF_ITER_STRIDE_IDX = 0;
        constexpr auto INDEX_ITER_STRIDE_IDX = 2;
        constexpr auto SRC_ITER_STRIDE_IDX = 1;
        using opmath_t = at::opmath_type<scalar_t>;
        _cpu_scatter_gather_dim_loop<is_scatter_like> loop_func;
        auto loop = [&](char** data, const int64_t* strides, int64_t n) {
          auto* self_data_bytes = data[SELF_ITER_STRIDE_IDX];
          auto* index_data_bytes = data[INDEX_ITER_STRIDE_IDX];
          auto* src_data_bytes = data[SRC_ITER_STRIDE_IDX];
          // we change between the TensorIterator-dim and dim-TensorIterator
          // loop orders depending on whether dim is the last dimension
          if (dim == buffer.dim() - 1) {
            for (const auto nelem C10_UNUSED : c10::irange(n)) {
              // dim loop is a separate code block
              // for better performance
              loop_func.template operator()<scalar_t, ReduceMaximum>(
                (opmath_t*)self_data_bytes, self_dim_stride,
                (int64_t*)index_data_bytes, index_dim_stride,
                (scalar_t*)src_data_bytes, src_dim_stride,
                dim, index_dim_size, index_upper_bound,
                kernel_func
              );

              self_data_bytes += strides[SELF_ITER_STRIDE_IDX];
              index_data_bytes += strides[INDEX_ITER_STRIDE_IDX];
              src_data_bytes += strides[SRC_ITER_STRIDE_IDX];
            }
          }
          else {
            for (const auto i : c10::irange(index_dim_size)) {
              auto* self_data = self_data_bytes;
              auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride);
              auto* src_data = src_data_bytes;
              for (const auto nelem C10_UNUSED : c10::irange(n)) {
                int64_t idx_dim = *(int64_t*)index_data;
                // we are not putting idx_dim in the error message because it disables
                // loop optimization in clang-7
                TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound,
                  "index ", *(int64_t*)index_data,
                  " is out of bounds for dimension ", dim,
                  " with size ", index_upper_bound);

                kernel_func(
                  (opmath_t*)self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride,
                  (scalar_t*)src_data + (is_scatter_like ? i : idx_dim) * src_dim_stride);

                self_data += strides[SELF_ITER_STRIDE_IDX];
                index_data += strides[INDEX_ITER_STRIDE_IDX];
                src_data += strides[SRC_ITER_STRIDE_IDX];
              }
            }
          }
        };
        iter.for_each(loop, grain_size);
      }
    );
    if (need_acc) {
      self.copy_(buffer);
    }
  }

  void operator()(const Tensor& self, int64_t dim,
    const Tensor& index, const Tensor& src,
    const std::string& method_name, ReduceMinimum& kernel_func) {

    Tensor buffer;
    bool need_acc = isReducedFloatingType(self.scalar_type());
    create_acc_buffer(buffer, self, need_acc);

    auto iter = TensorIteratorConfig()
      .check_all_same_dtype(false)
      .resize_outputs(false)
      // NOLINTNEXTLINE(bugprone-argument-comment)
      .declare_static_shape(index.sizes(), /*squash_dim=*/dim)
      .add_output(buffer)
      .add_const_input(src)
      .add_const_input(index)
      .build();

    auto self_dim_stride = ensure_nonempty_stride(buffer, dim);
    auto self_dim_size = ensure_nonempty_size(buffer, dim);

    auto index_dim_stride = ensure_nonempty_stride(index, dim);
    auto index_dim_size = ensure_nonempty_size(index, dim);

    auto src_dim_stride = ensure_nonempty_stride(src, dim);
    auto src_dim_size = ensure_nonempty_size(src, dim);

    auto index_upper_bound = is_scatter_like ? self_dim_size : src_dim_size;

    int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / index_dim_size);

    AT_DISPATCH_ALL_TYPES_AND3(
      ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, iter.dtype(1),
      "scatter_gather_tensor_cpu_reduce_amin", [&] {
        constexpr auto SELF_ITER_STRIDE_IDX = 0;
        constexpr auto INDEX_ITER_STRIDE_IDX = 2;
        constexpr auto SRC_ITER_STRIDE_IDX = 1;
        using opmath_t = at::opmath_type<scalar_t>;
        _cpu_scatter_gather_dim_loop<is_scatter_like> loop_func;
        auto loop = [&](char** data, const int64_t* strides, int64_t n) {
          auto* self_data_bytes = data[SELF_ITER_STRIDE_IDX];
          auto* index_data_bytes = data[INDEX_ITER_STRIDE_IDX];
          auto* src_data_bytes = data[SRC_ITER_STRIDE_IDX];
          // we change between the TensorIterator-dim and dim-TensorIterator
          // loop orders depending on whether dim is the last dimension
          if (dim == buffer.dim() - 1) {
            for (const auto nelem C10_UNUSED : c10::irange(n)) {
              // dim loop is a separate code block
              // for better performance
              loop_func.template operator()<scalar_t, ReduceMinimum>(
                (opmath_t*)self_data_bytes, self_dim_stride,
                (int64_t*)index_data_bytes, index_dim_stride,
                (scalar_t*)src_data_bytes, src_dim_stride,
                dim, index_dim_size, index_upper_bound,
                kernel_func
              );

              self_data_bytes += strides[SELF_ITER_STRIDE_IDX];
              index_data_bytes += strides[INDEX_ITER_STRIDE_IDX];
              src_data_bytes += strides[SRC_ITER_STRIDE_IDX];
            }
          }
          else {
            for (const auto i : c10::irange(index_dim_size)) {
              auto* self_data = self_data_bytes;
              auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride);
              auto* src_data = src_data_bytes;
              for (const auto nelem C10_UNUSED : c10::irange(n)) {
                int64_t idx_dim = *(int64_t*)index_data;
                // we are not putting idx_dim in the error message because it disables
                // loop optimization in clang-7
                TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound,
                  "index ", *(int64_t*)index_data,
                  " is out of bounds for dimension ", dim,
                  " with size ", index_upper_bound);

                kernel_func(
                  (opmath_t*)self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride,
                  (scalar_t*)src_data + (is_scatter_like ? i : idx_dim) * src_dim_stride);

                self_data += strides[SELF_ITER_STRIDE_IDX];
                index_data += strides[INDEX_ITER_STRIDE_IDX];
                src_data += strides[SRC_ITER_STRIDE_IDX];
              }
            }
          }
        };
        iter.for_each(loop, grain_size);
      }
    );
    if (need_acc) {
      self.copy_(buffer);
    }
  }
};

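// When ATen is built without FBGEMM, provide a stub with a matching signature
// so the expanded-index fast path below still compiles; actually reaching it
// at runtime is treated as an internal error.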
#ifndef USE_FBGEMM
namespace fbgemm {

template <typename K, typename V>
std::pair<K*, V*> radix_sort_parallel(
    K* const inp_key_buf,
    V* const inp_value_buf,
    K* const tmp_key_buf,
    V* const tmp_value_buf,
    const int64_t elements_count,
    const int64_t max_value) {
  TORCH_INTERNAL_ASSERT(false, "radix_sort_parallel: ATen not compiled with FBGEMM support");
  return std::make_pair(nullptr, nullptr);
}

} // namespace fbgemm
#endif

// Note [scatter reduce optimization]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//
// 1. motivation: optimize `scatter_reduce` for the classic PyG use case:
//    `scatter_reduce` is used extensively in 'message passing' when
//    aggregating info.
//
//    Typically `self` is a 2D tensor and `index` is a 1D extended/broadcasted
//    tensor, which means the aggregation is row-wise and we can vectorize over
//    the inner dimension.
//
// 2. implementation: map `scatter_reduce` to `spmm` reduce
//    in the shape of `[M, N]` * `[N, K]`, where:
//
//    M: self_dim_size
//    N (i.e. nnz): index_dim_size
//    K: index.numel() / index_dim_size
//
//    step 1: convert the input index to CSR format (use radix_sort to
//      resolve write-address conflicts on the `self` tensor)
//
//    step 2: spmm reduce, parallel on M and vectorized on K
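//
// A small worked example (hypothetical values): with M = 3 and
// index = [2, 0, 2, 1] (nnz = 4), sorting the (key = index value, value = position)
// pairs gives keys [0, 1, 2, 2] and values [1, 3, 0, 2]; the CSR structure is then
// row_index = [0, 1, 2] and row_index_offset = [0, 1, 2, 4], so self row 2 reduces
// over src rows 0 and 2 without write conflicts across threads.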
//

template <typename scalar_t, ReductionType reduce>
void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index, const Tensor& src, bool include_self) {
  const int64_t* index_data = index.const_data_ptr<int64_t>();
  scalar_t* self_data = self.data_ptr<scalar_t>();
  const scalar_t* src_data = src.const_data_ptr<scalar_t>();

  const int64_t M = ensure_nonempty_size(self, 0);
  const int64_t nnz = ensure_nonempty_size(index, 0);
  const int64_t K = index.numel() / nnz;

  const int64_t index_upper_bound = M;

  auto keys = std::make_unique<int64_t[]>(nnz);
  auto values = std::make_unique<int64_t[]>(nnz);
  auto keys_tmp = std::make_unique<int64_t[]>(nnz);
  auto values_tmp = std::make_unique<int64_t[]>(nnz);
  at::parallel_for(0, nnz, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      int64_t index = index_data[i];
      TORCH_CHECK(index >= 0 && index < index_upper_bound,
                  "index ", index,
                  " is out of bounds for dimension ", 0,
                  " with size ", index_upper_bound);
      keys[i] = index;
      values[i] = i;
    }
  });

  int64_t* sorted_col_index_keys = nullptr;
  int64_t* sorted_col_index_values = nullptr;
  std::tie(sorted_col_index_keys, sorted_col_index_values) = fbgemm::radix_sort_parallel(
      keys.get(),
      values.get(),
      keys_tmp.get(),
      values_tmp.get(),
      nnz,
      M);

  int num_threads = at::get_num_threads();
  std::vector<int64_t> num_uniq(num_threads, 0);
  at::parallel_for(1, nnz, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    for (const auto i : c10::irange(begin, end)) {
      if (sorted_col_index_keys[i] != sorted_col_index_keys[i - 1]) {
        num_uniq[tid]++;
      }
    }
  });
  num_uniq[0]++;
  for (const auto n : c10::irange(1, num_threads)) {
    num_uniq[n] += num_uniq[n - 1];
  }

  // in case some rows are not written into, num_nonzero_rows will be smaller than M
  int64_t num_nonzero_rows = num_uniq[num_threads - 1];
  auto row_index_tmp = std::make_unique<int64_t[]>(num_nonzero_rows);
  auto row_index_offset_tmp = std::make_unique<int64_t[]>(num_nonzero_rows + 1);
  int64_t* row_index = row_index_tmp.get();
  int64_t* row_index_offset = row_index_offset_tmp.get();
  row_index[0] = sorted_col_index_keys[0];
  row_index_offset[0] = 0;
  row_index_offset[num_nonzero_rows] = nnz;

  at::parallel_for(1, nnz, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    int64_t* t_index = row_index + ((tid == 0) ? 1 : num_uniq[tid - 1]);
    int64_t* t_index_offset = row_index_offset + ((tid == 0) ? 1 : num_uniq[tid - 1]);
    for (const auto i : c10::irange(begin, end)) {
      if (sorted_col_index_keys[i] != sorted_col_index_keys[i - 1]) {
        *t_index = sorted_col_index_keys[i];
        *t_index_offset = i;
        t_index++;
        t_index_offset++;
      }
    }
  });
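  // row_index / row_index_offset now hold the CSR-style row structure from the
  // note above: row_index_offset[m + 1] - row_index_offset[m] is the number of
  // src rows reduced into self row row_index[m].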

  using opmath_t = at::opmath_type<scalar_t>;
  Tensor buffer;
  opmath_t* buffer_data = nullptr;
  static constexpr bool need_acc = is_reduced_floating_point_v<scalar_t>;
  if constexpr (need_acc) {
    auto acc_type = at::toAccumulateType(self.scalar_type(), /*is_cuda=*/true);
    buffer = at::zeros({num_threads, K}, self.options().dtype(acc_type));
    buffer_data = buffer.data_ptr<opmath_t>();
  }

  // TODO: do blocking on col dimension to reduce WR bandwidth
  at::parallel_for(0, num_nonzero_rows, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    TORCH_CHECK(tid < num_threads,
                "expect thread id smaller than ", num_threads, ", got thread id ", tid);
    opmath_t* buffer_ptr = nullptr;

    for (const auto m : c10::irange(begin, end)) {
      int64_t row = row_index[m];
      int64_t off_start = row_index_offset[m];
      int64_t off_end = row_index_offset[m + 1];
      scalar_t* self_ptr = self_data + row * K;
      if constexpr (need_acc) {
        buffer_ptr = buffer_data + tid * K;
      } else {
        buffer_ptr = reinterpret_cast<opmath_t*>(self_ptr);
      }

      // step 1: reinit rows in `self` if needed
      _init<scalar_t, reduce>(self_ptr, buffer_ptr, K, include_self);

      // step 2: reduce
      for (const auto n : c10::irange(off_start, off_end)) {
        int64_t col = sorted_col_index_values[n];
        update<scalar_t, reduce>(buffer_ptr, src_data + col * K, K);
      }
      if constexpr (need_acc) {
        vec::convert(buffer_ptr, self_ptr, K);
      }

      // step 3: finalize
      int64_t count = include_self ? 1 : 0;
      count += off_end - off_start;
      write<scalar_t, reduce>(self_ptr, count, K);
    }
  });
}
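
// Note: _init / update / write above are the shared per-row reduction helpers
// pulled in through ATen/native/cpu/ReduceUtils.h; for reduced-precision types
// each row is accumulated in a per-thread fp32 buffer and converted back once.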

template <typename scalar_t>
void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& index, const Tensor& self) {
  const int64_t* index_data = index.const_data_ptr<int64_t>();
  scalar_t* result_data = result.data_ptr<scalar_t>();
  const scalar_t* self_data = self.const_data_ptr<scalar_t>();

  const int64_t M = ensure_nonempty_size(result, 0);
  const int64_t N = ensure_nonempty_size(self, 0);
  const int64_t K = index.numel() / M;

  const int64_t index_upper_bound = N;

  using Vec = vec::Vectorized<scalar_t>;
  int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / K);
  at::parallel_for(0, M, grain_size, [&](int64_t begin, int64_t end) {
    for (const auto m : c10::irange(begin, end)) {
      scalar_t* result_ptr = result_data + m * K;
      int64_t index = index_data[m];
      TORCH_CHECK(index >= 0 && index < index_upper_bound,
                  "index ", index,
                  " is out of bounds for dimension ", 0,
                  " with size ", index_upper_bound);
      const scalar_t* self_ptr = self_data + index * K;
      int64_t d = 0;
      for (; d < K - (K % Vec::size()); d += Vec::size()) {
        Vec out_vec = Vec::loadu(self_ptr + d);
        out_vec.store(result_ptr + d);
      }
#if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
# pragma unroll
#endif
      for (; d < K; d++) {
        result_ptr[d] = self_ptr[d];
      }
    }
  });
}
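
// For the expanded-index gather fast path above, the net effect is
//   result[m][k] = self[index[m]][k] for all k in [0, K),
// with the inner copy vectorized over k.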

void scatter_add_expanded_index_kernel(const Tensor& self, const Tensor& index, const Tensor& src) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
    ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "scatter_add_expanded_index", [&] {
      cpu_scatter_reduce_expanded_index<scalar_t, ReductionType::SUM>(self, index, src, /*include_self*/true);
  });
}

void scatter_reduce_expanded_index_kernel(
    const Tensor& self, const Tensor& index, const Tensor& src,
    const ReductionType& reduction, bool include_self) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
    ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "scatter_reduce_expanded_index", [&] {
      AT_DISPATCH_REDUCTION_TYPES(reduction, [&]() {
        cpu_scatter_reduce_expanded_index<scalar_t, reduce>(self, index, src, include_self);
      });
  });
}

void gather_expanded_index_kernel(const Tensor& result, const Tensor& self, const Tensor& index) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
    ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "gather_expanded_index", [&] {
      cpu_gather_expanded_index_kernel<scalar_t>(result, index, self);
  });
}

void gather_cpu_kernel(const Tensor& result, const Tensor& self, int64_t dim, const Tensor& index) {
  cpu_scatter_gather_base_kernel</*is_scatter_like=*/false>()(
    result, dim, index, self,
    "gather_out_cpu", tensor_assign);
}

void scatter_cpu_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
  cpu_scatter_gather_base_kernel<>()(
    self, dim, index, src, "scatter_cpu_", tensor_assign);
}

void scatter_fill_cpu_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& value) {
  cpu_scatter_gather_base_kernel<>()(
    self, dim, index, value, "scatter_fill_cpu_", tensor_assign);
}

void scatter_add_cpu_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
  cpu_scatter_gather_base_kernel<>()(
    self, dim, index, src,
    "scatter_add_", reduce_add);
}

void scatter_reduce_cpu_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
                               const Tensor& src, const ReductionType& reduce) {
  switch (reduce) {
  case ReductionType::SUM :
    cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                       "scatter_reduce_add_", reduce_add);
    break;
  case ReductionType::PROD :
    cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                       "scatter_reduce_multiply_", reduce_multiply);
    break;
  default :
    break;
  }
}

void scatter_reduce_two_cpu_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
                                   const Tensor& src, const ReductionType& reduce) {
  switch (reduce) {
  case ReductionType::SUM :
    cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                       "scatter_reduce_sum_", reduce_add);
    break;
  case ReductionType::PROD :
    cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                       "scatter_reduce_prod_", reduce_multiply);
    break;
  case ReductionType::MAX :
    cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                       "scatter_reduce_amax_", reduce_maximum);
    break;
  case ReductionType::MIN :
    cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                       "scatter_reduce_amin_", reduce_minimum);
    break;
  case ReductionType::MEAN :
    cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                       "scatter_reduce_mean_", reduce_mean);
    break;
  }
}

void scatter_scalar_reduce_cpu_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
                                      const Scalar& value, const ReductionType& reduce) {
  switch (reduce) {
  case ReductionType::SUM :
    cpu_scatter_gather_base_kernel<>()(self, dim, index, value,
                                       "scatter_scalar_reduce_add_", reduce_add);
    break;
  case ReductionType::PROD :
    cpu_scatter_gather_base_kernel<>()(self, dim, index, value,
                                       "scatter_scalar_reduce_multiply_", reduce_multiply);
    break;
  default:
    break;
  }
}

} // anonymous namespace

REGISTER_DISPATCH(gather_stub, &gather_cpu_kernel);
REGISTER_DISPATCH(scatter_stub, &scatter_cpu_kernel);
REGISTER_DISPATCH(scatter_fill_stub, &scatter_fill_cpu_kernel);
REGISTER_DISPATCH(scatter_add_stub, &scatter_add_cpu_kernel);
REGISTER_DISPATCH(scatter_reduce_stub, &scatter_reduce_cpu_kernel);
REGISTER_DISPATCH(scatter_scalar_reduce_stub, &scatter_scalar_reduce_cpu_kernel);
REGISTER_DISPATCH(scatter_reduce_two_stub, &scatter_reduce_two_cpu_kernel);

// fast paths for GNN usage
REGISTER_DISPATCH(scatter_add_expanded_index_stub, &scatter_add_expanded_index_kernel);
REGISTER_DISPATCH(scatter_reduce_expanded_index_stub, &scatter_reduce_expanded_index_kernel);
REGISTER_DISPATCH(gather_expanded_index_stub, &gather_expanded_index_kernel);

} // namespace at::native