#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/FunctionOfAMatrixUtils.h>

#include <ATen/Dispatch.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <limits>

namespace at::native {

namespace {

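// Generic elementwise kernel: each block covers n_threads * n_elems_per_thread
// consecutive indices; every thread processes n_elems_per_thread of them at a
// stride of n_threads, applying f(idx) to each in-bounds index.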
template <int n_threads, int n_elems_per_thread, typename func_t>
C10_LAUNCH_BOUNDS_2(n_threads, n_elems_per_thread)
__global__ void _elemwise_kernel(int total_n_elems, func_t f) {
  constexpr int total_work_block = n_threads * n_elems_per_thread;
  int idx = total_work_block * blockIdx.x + threadIdx.x;

  #pragma unroll
  for (int i = 0; i < n_elems_per_thread; ++i) {
    if (idx < total_n_elems) {
      f(idx);
      idx += n_threads;
    }
  }
}

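// Host-side launcher: derives the grid size from the per-block amount of work,
// launches _elemwise_kernel on the current CUDA stream, and checks the launch
// result. total_n_elems must fit in a signed 32-bit integer.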
template <int n_threads, int n_elems_per_thread, typename func_t>
void _launch_kernel(int total_n_elems, const func_t& f) {
  TORCH_INTERNAL_ASSERT(
    total_n_elems >= 0 && total_n_elems <= std::numeric_limits<int32_t>::max()
  );

  dim3 block(n_threads);
  constexpr int total_work_block = n_threads * n_elems_per_thread;
  dim3 grid((total_n_elems + total_work_block - 1) / total_work_block);

  auto stream = at::cuda::getCurrentCUDAStream();
  _elemwise_kernel<n_threads, n_elems_per_thread, func_t>
    <<<grid, block, 0, stream>>>(total_n_elems, f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

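// Accumulates into every output element of `iter` the sum over i in
// [0, num_summations) of in[i * in_stride] * coeff[i * coeff_stride].
// Iterators that are too large for 32-bit indexing are split into
// 32-bit-indexable sub-iterators first.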
template <typename scalar_t>
void _compute_linear_combination_internal_kernel(
  TensorIterator& iter,
  int32_t in_stride,
  int32_t coeff_stride,
  int32_t num_summations
) {
  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      _compute_linear_combination_internal_kernel<scalar_t>(
        sub_iter, in_stride, coeff_stride, num_summations
      );
    }
    return;
  }

  auto offset_calc = make_offset_calculator<3>(iter);
  char* __restrict__ out_ptr = reinterpret_cast<char*>(iter.data_ptr(0));
  char* __restrict__ in_ptr = reinterpret_cast<char*>(iter.data_ptr(1));
  char* __restrict__ coeff_ptr = reinterpret_cast<char*>(iter.data_ptr(2));

  auto loop = [=]C10_DEVICE(int idx) {
    auto offsets = offset_calc.get(idx);

    auto* __restrict__ out_data = reinterpret_cast<scalar_t*>(
      out_ptr + offsets[0]
    );
    auto* __restrict__ in_data = reinterpret_cast<scalar_t*>(
      in_ptr + offsets[1]
    );
    using primitive_t = typename scalar_value_type<scalar_t>::type;
    auto* __restrict__ coeff_data = reinterpret_cast<primitive_t*>(
      coeff_ptr + offsets[2]
    );

    // perform summation
    for (int32_t i = 0; i < num_summations; ++i) {
      *out_data += in_data[i * in_stride] * coeff_data[i * coeff_stride];
    }
  };

  _launch_kernel<num_threads(), thread_work_size()>(iter.numel(), loop);
}

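// Dispatch entry point: selects scalar_t from the iterator's dtype (all standard
// and complex types, plus Half, Bool, and BFloat16) and forwards to the typed
// internal kernel above.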
void _compute_linear_combination_cuda_kernel(
  TensorIterator& iter,
  int64_t in_stride,
  int64_t coeff_stride,
  int64_t num_summations
) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
    at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
    iter.dtype(),
    "_compute_linear_combination_cuda", [&] () {
      _compute_linear_combination_internal_kernel<scalar_t>(
        iter, in_stride, coeff_stride, num_summations
      );
    }
  );
}

} // anonymous namespace

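// Register this kernel as the CUDA backend for _compute_linear_combination_stub.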
REGISTER_DISPATCH(_compute_linear_combination_stub, &_compute_linear_combination_cuda_kernel);

} // namespace at::native