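// CUDA backend of the `_compute_linear_combination` dispatch stub declared in
// ATen/native/FunctionOfAMatrixUtils.h: accumulates linear combinations of
// slices of an input tensor with a matrix of coefficients (used, e.g., on the
// path of at::native::matrix_exp).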
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/FunctionOfAMatrixUtils.h>

#include <ATen/Dispatch.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <limits> // for std::numeric_limits in _launch_kernel

namespace at::native {

namespace {
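// Generic elementwise kernel: each block covers a contiguous chunk of
// `n_threads * n_elems_per_thread` indices; each thread handles
// `n_elems_per_thread` of them, spaced `n_threads` apart, and applies `f`
// to every in-bounds index.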
template <int n_threads, int n_elems_per_thread, typename func_t>
C10_LAUNCH_BOUNDS_2(n_threads, n_elems_per_thread)
__global__ void _elemwise_kernel(int total_n_elems, func_t f) {
  constexpr int total_work_block = n_threads * n_elems_per_thread;
  int idx = total_work_block * blockIdx.x + threadIdx.x;

  #pragma unroll
  for (int i = 0; i < n_elems_per_thread; ++i) {
    if (idx < total_n_elems) {
      f(idx);
      idx += n_threads;
    }
  }
}
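// Computes the grid size by ceil-division of the element count by the work
// per block, launches `_elemwise_kernel` on the current CUDA stream, and
// checks for launch errors. The element count must fit in an int32_t.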
template <int n_threads, int n_elems_per_thread, typename func_t>
void _launch_kernel(int total_n_elems, const func_t& f) {
  TORCH_INTERNAL_ASSERT(
    total_n_elems >= 0 && total_n_elems <= std::numeric_limits<int32_t>::max()
  );

  dim3 block(n_threads);
  constexpr int total_work_block = n_threads * n_elems_per_thread;
  dim3 grid((total_n_elems + total_work_block - 1) / total_work_block);

  auto stream = at::cuda::getCurrentCUDAStream();
  _elemwise_kernel<n_threads, n_elems_per_thread, func_t>
    <<<grid, block, 0, stream>>>(total_n_elems, f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
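// Computes, for every element of `iter`,
//   out += sum_{i=0}^{num_summations-1} in[i * in_stride] * coeff[i * coeff_stride]
// where the base addresses of out/in/coeff come from the iterator's offset
// calculator. Iterators too large for 32-bit indexing are split and handled
// recursively.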
template <typename scalar_t>
void _compute_linear_combination_internal_kernel(
  TensorIterator& iter,
  int32_t in_stride,
  int32_t coeff_stride,
  int32_t num_summations
) {
  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      _compute_linear_combination_internal_kernel<scalar_t>(
        sub_iter, in_stride, coeff_stride, num_summations
      );
    }
    return;
  }

  auto offset_calc = make_offset_calculator<3>(iter);
  char* __restrict__ out_ptr = reinterpret_cast<char*>(iter.data_ptr(0));
  char* __restrict__ in_ptr = reinterpret_cast<char*>(iter.data_ptr(1));
  char* __restrict__ coeff_ptr = reinterpret_cast<char*>(iter.data_ptr(2));

  auto loop = [=]C10_DEVICE(int idx) {
    // Byte offsets of this element into the output, input, and coefficient
    // tensors, in that order.
    auto offsets = offset_calc.get(idx);

    auto* __restrict__ out_data = reinterpret_cast<scalar_t*>(
      out_ptr + offsets[0]
    );
    auto* __restrict__ in_data = reinterpret_cast<scalar_t*>(
      in_ptr + offsets[1]
    );
    // Coefficients are stored as the underlying real type of scalar_t.
    using primitive_t = typename scalar_value_type<scalar_t>::type;
    auto* __restrict__ coeff_data = reinterpret_cast<primitive_t*>(
      coeff_ptr + offsets[2]
    );

    // perform summation
    for (int32_t i = 0; i < num_summations; ++i) {
      *out_data += in_data[i * in_stride] * coeff_data[i * coeff_stride];
    }
  };

  _launch_kernel<num_threads(), thread_work_size()>(iter.numel(), loop);
}
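// Dispatch entry point for the stub: instantiates the internal kernel for
// all standard dtypes, including Bool, Half, BFloat16, and the complex
// types; the int64_t stride/count arguments are narrowed to int32_t here.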
void _compute_linear_combination_cuda_kernel(
  TensorIterator& iter,
  int64_t in_stride,
  int64_t coeff_stride,
  int64_t num_summations
) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
    at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
    iter.dtype(),
    "_compute_linear_combination_cuda", [&] () {
      _compute_linear_combination_internal_kernel<scalar_t>(
        iter, in_stride, coeff_stride, num_summations
      );
    }
  );
}

} // anonymous namespace

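// Register the CUDA implementation with the device-generic dispatch stub.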
REGISTER_DISPATCH(_compute_linear_combination_stub, &_compute_linear_combination_cuda_kernel);

} // namespace at::native