// aten/src/ATen/native/cpu/SpmmReduceKernel.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/SpmmReduceKernel.h>
#include <ATen/native/cpu/ReduceUtils.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
#include <ATen/OpMathType.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::native {

namespace {

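// Accumulate `val * other[c, :]` into the row accumulation buffer `out_ptr`
// (kept in opmath_t precision). The main loop is unrolled 4x over vector
// lanes (kVLEN = 4 * kVecSize), followed by a single-vector loop and a
// scalar tail for the remaining elements. `update<T, reduce>` (from
// ReduceUtils.h) selects the combine op for the reduction type; the nonzero
// index `e` is not used here.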
template <typename scalar_t, typename index_t, ReductionType reduce>
inline void _update(at::opmath_type<scalar_t>* out_ptr, int64_t e, int64_t c, const scalar_t val, const scalar_t* other_data, int64_t K) {
  using opmath_t = at::opmath_type<scalar_t>;
  using Vec = vec::Vectorized<scalar_t>;
  using aVec = VecType<scalar_t>;
  constexpr int64_t kVecSize = Vec::size();
  constexpr int64_t kVLEN = kVecSize * 4;

  int64_t k = 0;
  aVec val_vec = aVec((opmath_t)val);
  const scalar_t* other_ptr = other_data + c * K;

  for (; k < K - (K % kVLEN); k += kVLEN) {
    aVec out_vec0 = aVec::loadu(out_ptr + k);
    aVec out_vec1 = aVec::loadu(out_ptr + k + kVecSize);
    aVec out_vec2 = aVec::loadu(out_ptr + k + kVecSize * 2);
    aVec out_vec3 = aVec::loadu(out_ptr + k + kVecSize * 3);

    out_vec0 = update<aVec, reduce>(out_vec0, aVec::loadu(other_ptr + k) * val_vec);
    out_vec1 = update<aVec, reduce>(out_vec1, aVec::loadu(other_ptr + k + kVecSize) * val_vec);
    out_vec2 = update<aVec, reduce>(out_vec2, aVec::loadu(other_ptr + k + kVecSize * 2) * val_vec);
    out_vec3 = update<aVec, reduce>(out_vec3, aVec::loadu(other_ptr + k + kVecSize * 3) * val_vec);

    out_vec0.store(out_ptr + k);
    out_vec1.store(out_ptr + k + kVecSize);
    out_vec2.store(out_ptr + k + kVecSize * 2);
    out_vec3.store(out_ptr + k + kVecSize * 3);
  }
  for (; k < K - (K % kVecSize); k += kVecSize) {
    aVec out_vec = aVec::loadu(out_ptr + k);
    out_vec = update<aVec, reduce>(out_vec, aVec::loadu(other_ptr + k) * val_vec);
    out_vec.store(out_ptr + k);
  }
  for (; k < K; k++) {
    opmath_t out_val = opmath_t(out_ptr[k]);
    out_val = update<opmath_t, reduce>(out_val, opmath_t(other_ptr[k]) * opmath_t(val));
    out_ptr[k] = out_val;
  }
}

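// spmm_reduce forward: for a CSR sparse matrix (crow_indices, col_indices,
// values) of shape [M, N] and a dense `other` of shape [N, K], compute
// out[m, :] = reduce_{e in row m} values[e] * other[col_indices[e], :].
// Rows are partitioned across threads via parallel_sparse_csr so that each
// thread handles a roughly balanced number of nonzeros. For reduced floating
// point types (BFloat16/Half) a per-thread fp32 buffer is used for
// accumulation and converted back to the output dtype after each row.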
template <typename scalar_t, typename index_t, ReductionType reduce>
void spmm_reduce_kernel_impl(
    const Tensor& out,
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& values,
    const Tensor& other_) {

  int64_t nnz = values.numel();
  if (nnz == 0) {
    return;
  }

  auto other = other_.contiguous();

  // access `crow_indices`, `col_indices` and `values` via TensorAccessor
  scalar_t* out_data = out.data_ptr<scalar_t>();
  auto csr_data = crow_indices.accessor<const index_t, 1>();
  auto col_data = col_indices.accessor<const index_t, 1>();
  auto val_data = values.accessor<const scalar_t, 1>();
  const scalar_t* other_data = other.const_data_ptr<scalar_t>();

  int64_t M = crow_indices.numel() - 1;
  int64_t K = other.size(-1);

  int num_threads = at::get_num_threads();
  using opmath_t = at::opmath_type<scalar_t>;
  Tensor buffer;
  opmath_t* buffer_data = nullptr;
  static constexpr bool need_acc = is_reduced_floating_point_v<scalar_t>;
  if constexpr (need_acc) {
    auto acc_type = at::toAccumulateType(out.scalar_type(), /*is_cuda=*/true);
    buffer = at::zeros({num_threads, K}, out.options().dtype(acc_type));
    buffer_data = buffer.data_ptr<opmath_t>();
  }

  utils::parallel_sparse_csr(csr_data, M, nnz, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    TORCH_CHECK(tid < num_threads,
                "expect thread id smaller than ", num_threads, ", got thread id ", tid);
    opmath_t* buffer_ptr = nullptr;

    int64_t row_start = 0, row_end = 0;
    for (const auto m : c10::irange(begin, end)) {
      row_start = csr_data[m];
      row_end = csr_data[m + 1];

      scalar_t* out_ptr = out_data + m * K;
      if constexpr (need_acc) {
        buffer_ptr = buffer_data + tid * K;
      } else {
        buffer_ptr = reinterpret_cast<opmath_t*>(out_ptr);
      }

      // step 1: reinit the output row for reduce types 'amax' and 'amin'
      int64_t count = row_end - row_start;
      if (count != 0) {
        _init<scalar_t, reduce>(out_ptr, buffer_ptr, K, /*include_self*/false);
      }

      // step 2: reduce, blocking along the row to reduce write memory bandwidth
      constexpr int64_t CHUNK_SIZE = 16;
      for (int64_t e0 = row_start; e0 < row_end; e0 += CHUNK_SIZE) {
        int64_t e1 = std::min(e0 + CHUNK_SIZE, row_end);
        for (const auto e : c10::irange(e0, e1)) {
          int64_t c = col_data[e];
          scalar_t val = val_data[e];
          _update<scalar_t, index_t, reduce>(buffer_ptr, e, c, val, other_data, K);
        }
      }
      if constexpr (need_acc) {
        if (count != 0) {
          vec::convert(buffer_ptr, out_ptr, K);
        }
      }

      // step 3: finalize
      write<scalar_t, reduce>(out_ptr, count, K);
    }
  });
}

// update both val and arg, used for `amin` and `amax`.
// This is awkward to vectorize since `scalar_t` and `index_t` may have
// different vector lengths; for example, a vector that holds 8 floats
// holds only 4 int64_t values.
template <typename scalar_t, typename index_t, ReductionType reduce>
inline void update_with_index(scalar_t* val, scalar_t new_val, index_t* arg, index_t new_arg) {
  if ((reduce == ReductionType::MIN && new_val < *val) ||
      (reduce == ReductionType::MAX && new_val > *val) ||
      at::_isnan<scalar_t>(new_val)) {
    *val = new_val;
    *arg = new_arg;
  }
}

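// spmm_reduce with arg output: same traversal as the kernel above, but for
// 'amax'/'amin' it also records, per output element, the index `e` of the
// nonzero that produced the winning value (`arg_out`). Because the value and
// its index must be updated together, the inner loop over K stays scalar
// instead of vectorized.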
template <typename scalar_t, typename index_t, ReductionType reduce>
void spmm_reduce_arg_kernel_impl(
    const Tensor& out,
    const Tensor& arg_out,
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& values,
    const Tensor& other_) {

  TORCH_CHECK(reduce == ReductionType::MAX || reduce == ReductionType::MIN);
  int64_t nnz = values.numel();
  if (nnz == 0) {
    return;
  }

  auto other = other_.contiguous();

  scalar_t* out_data = out.data_ptr<scalar_t>();
  index_t* arg_out_data = arg_out.data_ptr<index_t>();
  auto csr_data = crow_indices.accessor<const index_t, 1>();
  auto col_data = col_indices.accessor<const index_t, 1>();
  auto val_data = values.accessor<const scalar_t, 1>();
  const scalar_t* other_data = other.const_data_ptr<scalar_t>();

  int64_t M = crow_indices.numel() - 1;
  int64_t K = other.size(-1);

  int num_threads = at::get_num_threads();
  using opmath_t = at::opmath_type<scalar_t>;
  Tensor buffer;
  opmath_t* buffer_data = nullptr;
  static constexpr bool need_acc = is_reduced_floating_point_v<scalar_t>;
  if constexpr (need_acc) {
    auto acc_type = at::toAccumulateType(out.scalar_type(), /*is_cuda=*/true);
    buffer = at::zeros({num_threads, K}, out.options().dtype(acc_type));
    buffer_data = buffer.data_ptr<opmath_t>();
  }

  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    TORCH_CHECK(tid < num_threads,
                "expect thread id smaller than ", num_threads, ", got thread id ", tid);
    opmath_t* buffer_ptr = nullptr;

    int64_t row_start = 0, row_end = 0, c = 0;
    for (const auto m : c10::irange(begin, end)) {
      row_start = csr_data[m];
      row_end = csr_data[m + 1];

      scalar_t* out_ptr = out_data + m * K;
      index_t* arg_out_ptr = arg_out_data + m * K;
      if constexpr (need_acc) {
        buffer_ptr = buffer_data + tid * K;
      } else {
        buffer_ptr = reinterpret_cast<opmath_t*>(out_ptr);
      }

      if (row_end != row_start) {
        _init<scalar_t, reduce>(out_ptr, buffer_ptr, K, /*include_self*/false);
        for (const auto e : c10::irange(row_start, row_end)) {
          c = col_data[e];
          opmath_t val = opmath_t(val_data[e]);

          const scalar_t* other_ptr = other_data + c * K;
          for (const auto k : c10::irange(K)) {
            update_with_index<opmath_t, index_t, reduce>(
                &buffer_ptr[k], opmath_t(val * other_ptr[k]), &arg_out_ptr[k], index_t(e));
          }
        }
      }
      if constexpr (need_acc) {
        if (row_end != row_start) {
          vec::convert(buffer_ptr, out_ptr, K);
        }
      }
    }
  });
}

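// Backward w.r.t. the sparse input for reduce types 'sum' and 'mean':
// grad_values[i] = dot(other[col_indices[i], :], grad_out[row_indices[i], :]),
// divided by the row's nonzero count for 'mean'.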
template <typename scalar_t, typename index_t, ReductionType reduce>
void spmm_reduce_backward_input_kernel_impl(
    const Tensor& grad_self,
    const Tensor& grad_out_,
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& other_,
    const Tensor& row_indices) {

  int64_t nnz = grad_self._nnz();
  if (nnz == 0) {
    return;
  }

  auto grad_out = grad_out_.contiguous();
  auto other = other_.contiguous();

  auto values = grad_self.values();
  auto grad_values_data = values.accessor<scalar_t, 1>();
  const scalar_t* grad_out_data = grad_out.const_data_ptr<scalar_t>();
  auto crow_data = crow_indices.accessor<const index_t, 1>();
  auto col_data = col_indices.accessor<const index_t, 1>();
  const scalar_t* other_data = other.const_data_ptr<scalar_t>();
  auto row_data = row_indices.accessor<const index_t, 1>();

  int64_t K = grad_out.size(1);

  using Vec = vec::Vectorized<vec::vec_scalar_t<scalar_t>>;
  at::parallel_for(0, nnz, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      index_t row = row_data[i], col = col_data[i];

      scalar_t val = vec::map2_reduce_all<scalar_t>(
          [](Vec x, Vec y) { return x * y; },
          [](Vec x, Vec y) { return x + y; },
          other_data + col * K,
          grad_out_data + row * K,
          K);

      if (reduce == ReductionType::MEAN) {
        index_t row_start = crow_data[row], row_end = crow_data[row + 1];
        val /= (row_end - row_start);
      }

      grad_values_data[i] = val;
    }
  });
}

// backward for reduce type 'amax' or 'amin'
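// Gradient flows only through the nonzero selected by `arg_out`: for each
// output element (m, k) with arg_out[m, k] == e < nnz, grad_values[e]
// accumulates other[col_indices[e], k] * grad_out[m, k]; entries equal to
// nnz mark positions with no contributing nonzero and are skipped.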
template <typename scalar_t, typename index_t>
void spmm_reduce_backward_input_arg_kernel_impl(
    const Tensor& grad_self,
    const Tensor& grad_out_,
    const Tensor& col_indices,
    const Tensor& other_,
    const Tensor& arg_out_) {

  int64_t nnz = grad_self._nnz();
  if (nnz == 0) {
    return;
  }

  auto grad_out = grad_out_.contiguous();
  auto other = other_.contiguous();
  auto arg_out = arg_out_.contiguous();

  auto grad_values = grad_self.values();
  auto grad_values_data = grad_values.accessor<scalar_t, 1>();
  const scalar_t* grad_out_data = grad_out.const_data_ptr<scalar_t>();
  auto col_data = col_indices.accessor<const index_t, 1>();
  const scalar_t* other_data = other.const_data_ptr<scalar_t>();
  index_t* arg_out_data = arg_out.data_ptr<index_t>();

  int64_t M = grad_out.size(0);
  int64_t K = grad_out.size(1);
  auto grad = at::empty({M, K}, grad_out.options());
  scalar_t* grad_data = grad.mutable_data_ptr<scalar_t>();

  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto m : c10::irange(begin, end)) {
      const scalar_t* grad_out_ptr = grad_out_data + m * K;
      scalar_t* grad_ptr = grad_data + m * K;
      index_t* arg_out_ptr = arg_out_data + m * K;

      for (const auto k : c10::irange(K)) {
        if (arg_out_ptr[k] == index_t(nnz)) {
          grad_ptr[k] = scalar_t(0);
        } else {
          // collect weight at max/min indices
          index_t col = col_data[arg_out_data[m * K + k]];
          grad_ptr[k] = other_data[col * K + k] * grad_out_ptr[k];
        }
      }
    }
  });

  // scatter_add; consider parallelizing this with atomics
  for (const auto i : c10::irange(M * K)) {
    index_t ind = arg_out_data[i];
    if (ind != index_t(nnz)) {
      grad_values_data[ind] += grad_data[i];
    }
  }
}

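// Divide each nonzero value by the nonzero count of its row. Used by the
// 'mean' backward-other path below so that the plain 'sum' forward kernel
// can be reused to produce the mean gradient.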
template <typename scalar_t, typename index_t>
void spmm_reduce_normalize_values_kernel_impl(
    const Tensor& normalized_values,
    const Tensor& values,
    const Tensor& crow_indices,
    const Tensor& row_indices) {

  int64_t nnz = values.numel();
  if (nnz == 0) {
    return;
  }

  auto normalized_values_data = normalized_values.accessor<scalar_t, 1>();
  auto values_data = values.accessor<scalar_t, 1>();
  auto crow_data = crow_indices.accessor<index_t, 1>();
  auto row_data = row_indices.accessor<index_t, 1>();

  at::parallel_for(0, nnz, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      index_t row = row_data[i];
      index_t row_start = crow_data[row], row_end = crow_data[row + 1];
      // Note: if a row index appears in row_indices, the row has at least one
      // nonzero, so crow_indices[row + 1] > crow_indices[row] and the divisor
      // is nonzero.
      normalized_values_data[i] = values_data[i] / (row_end - row_start);
    }
  });
}

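// Backward w.r.t. the dense `other` for 'amax'/'amin': each output element
// (m, k) routes its gradient, scaled by the selected nonzero's value, to row
// col_indices[arg_out[m, k]] of grad_other; positions with arg_out == nnz
// received no contribution and are skipped.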
template <typename scalar_t, typename index_t>
void spmm_reduce_backward_other_arg_kernel_impl(
    const Tensor& grad_other,
    const Tensor& grad_out_,
    const Tensor& col_indices,
    const Tensor& values,
    const Tensor& arg_out_) {

  int64_t nnz = values.numel();
  if (nnz == 0) {
    return;
  }

  auto grad_out = grad_out_.contiguous();
  auto arg_out = arg_out_.contiguous();

  scalar_t* grad_other_data = grad_other.data_ptr<scalar_t>();
  const scalar_t* grad_out_data = grad_out.const_data_ptr<scalar_t>();
  auto col_data = col_indices.accessor<const index_t, 1>();
  auto values_data = values.accessor<const scalar_t, 1>();
  const index_t* arg_out_data = arg_out.const_data_ptr<index_t>();

  int64_t M = grad_out.size(0);
  int64_t K = grad_out.size(1);
  auto grad = at::empty({M, K}, grad_out.options());
  scalar_t* grad_data = grad.mutable_data_ptr<scalar_t>();

  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto m : c10::irange(begin, end)) {
      const scalar_t* grad_out_ptr = grad_out_data + m * K;
      scalar_t* grad_ptr = grad_data + m * K;
      const index_t* arg_out_ptr = arg_out_data + m * K;

      for (const auto k : c10::irange(K)) {
        if (arg_out_ptr[k] == index_t(nnz)) {
          grad_ptr[k] = scalar_t(0);
        } else {
          grad_ptr[k] = values_data[arg_out_ptr[k]] * grad_out_ptr[k];
        }
      }
    }
  });

  // scatter_add; consider parallelizing this with atomics
  for (const auto m : c10::irange(M)) {
    for (const auto k : c10::irange(K)) {
      index_t ind = arg_out_data[m * K + k];
      if (ind != index_t(nnz)) {
        index_t col = col_data[ind];
        grad_other_data[col * K + k] += grad_data[m * K + k];
      }
    }
  }
}

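// Dispatch wrappers registered below: the AT_DISPATCH macros instantiate the
// kernels above for the supported floating point types (including
// BFloat16/Half), the index type of the sparse indices, and the requested
// reduction type.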
void spmm_reduce_kernel(
    const Tensor& out,
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& values,
    const Tensor& other,
    ReductionType reduce_op) {
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_kernel", [&]() {
    AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() {
      AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
        spmm_reduce_kernel_impl<scalar_t, index_t, reduce>(
            out, crow_indices, col_indices, values, other);
      });
    });
  });
}

void spmm_reduce_arg_kernel(
    const Tensor& out,
    const Tensor& arg_out,
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& values,
    const Tensor& other,
    ReductionType reduce_op) {
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_kernel", [&]() {
    AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() {
      AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
        spmm_reduce_arg_kernel_impl<scalar_t, index_t, reduce>(
            out, arg_out, crow_indices, col_indices, values, other);
      });
    });
  });
}

void spmm_reduce_backward_input_kernel(
    const Tensor& grad_self,
    const Tensor& grad_out,
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& other,
    const Tensor& row_indices,
    ReductionType reduce_op) {
  TORCH_CHECK(reduce_op == ReductionType::SUM || reduce_op == ReductionType::MEAN);
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, other.scalar_type(), "spmm_reduce_backward_input_kernel", [&]() {
    AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_input_indices", [&]() {
      AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
        spmm_reduce_backward_input_kernel_impl<scalar_t, index_t, reduce>(
            grad_self, grad_out, crow_indices, col_indices, other, row_indices);
      });
    });
  });
}

void spmm_reduce_backward_input_arg_kernel(
    const Tensor& grad_self,
    const Tensor& grad_out,
    const Tensor& col_indices,
    const Tensor& other,
    const Tensor& arg_out,
    ReductionType reduce_op) {
  TORCH_CHECK(reduce_op == ReductionType::MAX || reduce_op == ReductionType::MIN);
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, other.scalar_type(), "spmm_reduce_backward_input_arg_kernel", [&]() {
    AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_input_arg_indices", [&]() {
      spmm_reduce_backward_input_arg_kernel_impl<scalar_t, index_t>(
          grad_self, grad_out, col_indices, other, arg_out);
    });
  });
}

void spmm_reduce_normalize_values_kernel(
    const Tensor& normalized_values,
    const Tensor& values,
    const Tensor& crow_indices,
    const Tensor& row_indices) {
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_normalize_values_kernel", [&]() {
    AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "spmm_reduce_normalize_values_indices", [&]() {
      spmm_reduce_normalize_values_kernel_impl<scalar_t, index_t>(
          normalized_values, values, crow_indices, row_indices);
    });
  });
}

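// Backward w.r.t. the dense `other` for 'sum'/'mean': grad_other = A^T @ grad_out.
// Implemented by permuting the CSR nonzeros into CSC order via `csr2csc` and
// reusing the forward 'sum' kernel; for 'mean' the values are first divided
// by their row counts.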
void spmm_reduce_backward_other_kernel(
    const Tensor& grad_other,
    const Tensor& grad_out,
    const Tensor& crow_indices,
    const Tensor& values,
    const Tensor& row_indices,
    const Tensor& ccol_indices,
    const Tensor& csr2csc,
    ReductionType reduce_op) {
  TORCH_CHECK(reduce_op == ReductionType::SUM || reduce_op == ReductionType::MEAN);
  // need to permute row_indices to CSC order
  auto row = row_indices.index_select(0, csr2csc);

  Tensor val;
  if (reduce_op == ReductionType::MEAN) {
    // for reduce type 'mean', normalize each nonzero value
    // by the nonzero count of its row.
    Tensor normalized_values = at::empty(values.sizes(), values.options());
    spmm_reduce_normalize_values_kernel(normalized_values, values, crow_indices, row_indices);
    val = normalized_values.index_select(0, csr2csc);
  } else {
    val = values.index_select(0, csr2csc);
  }

  spmm_reduce_kernel(grad_other, ccol_indices, row, val, grad_out, ReductionType::SUM);
}

void spmm_reduce_backward_other_arg_kernel(
    const Tensor& grad_other,
    const Tensor& grad_out,
    const Tensor& col_indices,
    const Tensor& values,
    const Tensor& arg_out,
    ReductionType reduce_op) {
  TORCH_CHECK(reduce_op == ReductionType::MAX || reduce_op == ReductionType::MIN);
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_backward_other_arg_kernel", [&]() {
    AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_other_arg_indices", [&]() {
      spmm_reduce_backward_other_arg_kernel_impl<scalar_t, index_t>(
          grad_other, grad_out, col_indices, values, arg_out);
    });
  });
}

} // anonymous namespace

REGISTER_DISPATCH(spmm_reduce_stub, &spmm_reduce_kernel);
REGISTER_DISPATCH(spmm_reduce_arg_stub, &spmm_reduce_arg_kernel);
REGISTER_DISPATCH(spmm_reduce_backward_input_stub, &spmm_reduce_backward_input_kernel);
REGISTER_DISPATCH(spmm_reduce_backward_input_arg_stub, &spmm_reduce_backward_input_arg_kernel);
REGISTER_DISPATCH(spmm_reduce_backward_other_stub, &spmm_reduce_backward_other_kernel);
REGISTER_DISPATCH(spmm_reduce_backward_other_arg_stub, &spmm_reduce_backward_other_arg_kernel);

} // at::native