xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/CUDALoops.cuh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // This file provides two functions to help write GPU elementwise kernels:
4 //
5 //   gpu_kernel(TensorIterator iter, <lambda>)
6 //   gpu_kernel_with_scalars(TensorIterator iter, <lambda>)
7 //
8 // The gpu_kernel_with_scalars generates specializations that support a
9 // single scalar CPU argument, such as from `cuda_tensor + 5`. The CPU scalar
10 // is lifted to a kernel parameter instead of copying to device memory.
// This should be used in conjunction with TensorIterator::allow_cpu_scalars_,
12 // which is the default for TensorIterator::binary_op. Otherwise, all inputs
13 // and the output must be on the GPU.
14 //
15 // For example, to write a reciprocal kernel for GPU float Tensors:
16 //
17 //   gpu_kernel(iter, []GPU_LAMBDA(float a) {
18 //    return 1.0f / a;
19 //   });
20 //
21 // To write a multiplication kernel for GPU float Tensors where one argument
22 // may be a CPU scalar:
23 //
24 //   gpu_kernel_with_scalars(iter, []GPU_LAMBDA(float a, float b) {
25 //     return a * b;
26 //   });
27 //
28 // See BinaryOpsKernel.cu for the complete implementation
29 //
30 
31 #include <iostream>
32 #include <tuple>
33 #include <type_traits>
34 
35 #include <ATen/core/Array.h>
36 #include <ATen/cuda/CUDAContext.h>
37 #include <ATen/detail/FunctionTraits.h>
38 #include <ATen/native/TensorIterator.h>
39 #include <c10/core/DynamicCast.h>
40 #include <c10/core/ScalarType.h>
41 #include <c10/macros/Macros.h>
42 #include <c10/util/TypeCast.h>
43 
// ASSERT_HOST_DEVICE_LAMBDA(type)
//   When compiling with NVCC, statically asserts that `type` is the closure
//   type of an extended __host__ __device__ lambda, failing with
//   "<type> must be a __host__ __device__ lambda".
//   When __NVCC__ is not defined the macro expands to nothing.
#if defined(__NVCC__)
#define ASSERT_HOST_DEVICE_LAMBDA(type)                                   \
  static_assert(                                                          \
      __nv_is_extended_host_device_lambda_closure_type(type),             \
      #type " must be a __host__ __device__ lambda")
#else
#define ASSERT_HOST_DEVICE_LAMBDA(type)
#endif
51 #endif
52 
53 namespace at {
54 namespace native {
55 
// Elementwise kernel for contiguous operands that need no dtype casting.
//
// Block b is responsible for the elements starting at block_work_size() * b.
// A block that still has at least block_work_size() elements ahead of it uses
// vectorized memory access (memory::policies::vectorized<vec_size>). A block
// with fewer elements left handles the remainder with a plain unrolled loop
// bounded by the remaining count. With a ceil-divided grid, only the final
// block can be in that situation.
template <int vec_size, typename func_t, typename array_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
  using traits = function_traits<func_t>;
  // Elements from this block's first element to the end of the input.
  int remaining = N - block_work_size() * blockIdx.x;

  if (remaining >= block_work_size()) {
    // Full tile: vectorized loads and stores.
    elementwise_kernel_helper(
        f, memory::policies::vectorized<vec_size, array_t>(data));
    return;
  }

  // Remainder tile: trivial 1-D offsets, no casting, bounded by `remaining`.
  auto in_offsets = TrivialOffsetCalculator<traits::arity>();
  auto out_offsets = TrivialOffsetCalculator<1>();
  auto plain_loader = memory::LoadWithoutCast();
  auto plain_storer = memory::StoreWithoutCast();
  auto tail_policy = memory::policies::unroll<
      array_t,
      decltype(in_offsets),
      decltype(out_offsets),
      memory::LoadWithoutCast,
      memory::StoreWithoutCast>(
      data, remaining, in_offsets, out_offsets, plain_loader, plain_storer);
  elementwise_kernel_helper(f, tail_policy);
}
82 
// Generic unrolled elementwise kernel.
//
// Block b processes the elements starting at block_work_size() * b. The
// policy is bounded by the number of elements left from that point.
// `ic` / `oc` map a linear element index to input / output offsets.
// `l` / `s` perform the loads and stores. Callers in this file pass both
// casting and non-casting loader/storer types.
template <
    typename func_t,
    typename array_t,
    typename inp_calc_t,
    typename out_calc_t,
    typename loader_t,
    typename storer_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void unrolled_elementwise_kernel(
    int N,
    func_t f,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s) {
  using policy_t = memory::policies::
      unroll<array_t, inp_calc_t, out_calc_t, loader_t, storer_t>;
  int elems_left = N - block_work_size() * blockIdx.x;
  policy_t block_policy(data, elems_left, ic, oc, l, s);
  elementwise_kernel_helper(f, block_policy);
}
105 
// Launches an elementwise kernel over N elements on the current CUDA stream.
// Assumes trivial 1-D (contiguous) addressing and no dynamic casting.
//
// The vector width reported by memory::can_vectorize_up_to selects the kernel:
//   4 or 2 -> vectorized_elementwise_kernel<width>
//   1      -> unrolled_elementwise_kernel with trivial offsets and
//             LoadWithoutCast / StoreWithoutCast
// Any other width is an internal error.
// Requires 0 < N <= INT32_MAX, because the kernels take `int N`.
template <typename func_t, typename array_t>
static inline void launch_vectorized_kernel(
    int64_t N,
    const func_t& f,
    array_t data) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  using traits = function_traits<func_t>;
  const int64_t num_blocks =
      (N + block_work_size() - 1) / block_work_size();
  auto cuda_stream = at::cuda::getCurrentCUDAStream();
  const int vector_width = memory::can_vectorize_up_to<func_t>(data);

  if (vector_width == 4) {
    vectorized_elementwise_kernel<4, func_t, array_t>
        <<<num_blocks, num_threads(), 0, cuda_stream>>>(N, f, data);
  } else if (vector_width == 2) {
    vectorized_elementwise_kernel<2, func_t, array_t>
        <<<num_blocks, num_threads(), 0, cuda_stream>>>(N, f, data);
  } else if (vector_width == 1) {
    auto in_offsets = TrivialOffsetCalculator<traits::arity>();
    auto out_offsets = TrivialOffsetCalculator<1>();
    auto plain_loader = memory::LoadWithoutCast();
    auto plain_storer = memory::StoreWithoutCast();
    unrolled_elementwise_kernel<func_t, array_t>
        <<<num_blocks, num_threads(), 0, cuda_stream>>>(
            N, f, data, in_offsets, out_offsets, plain_loader, plain_storer);
  } else {
    // Always throws, so the launch check below only follows a real launch.
    TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
144 
// Launches unrolled_elementwise_kernel over N elements on the current CUDA
// stream, using the caller-supplied offset calculators and loader/storer.
// Each block covers block_work_size() elements.
// Requires 0 < N <= INT32_MAX, because the kernel takes `int N`.
template <
    typename func_t,
    typename array_t,
    typename inp_calc_t,
    typename out_calc_t,
    typename loader_t,
    typename storer_t>
static inline void launch_unrolled_kernel(
    int64_t N,
    const func_t& f,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  const int64_t work_per_block = block_work_size();
  const int64_t num_blocks = (N + work_per_block - 1) / work_per_block;
  auto cuda_stream = at::cuda::getCurrentCUDAStream();
  unrolled_elementwise_kernel<func_t, array_t>
      <<<num_blocks, num_threads(), 0, cuda_stream>>>(
          N, f, data, ic, oc, l, s);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
167 
// Legacy elementwise kernel.
//
// Block b owns the nt * vt consecutive indices starting at nt * vt * b. Each
// thread starts at b * nt * vt + threadIdx.x and calls f on up to vt indices,
// stepping by nt and skipping every index >= N.
template <int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS_2(nt, 4)
__global__ void elementwise_kernel(int N, func_t f) {
  constexpr int items_per_block = nt * vt;
  int linear_idx = items_per_block * blockIdx.x + threadIdx.x;
#pragma unroll
  for (int step = 0; step < vt; step++) {
    // Once linear_idx reaches N it stops advancing, so later steps are no-ops.
    if (linear_idx < N) {
      f(linear_idx);
      linear_idx += nt;
    }
  }
}
182 
// Launches elementwise_kernel<nt, vt> over indices [0, N) on the current CUDA
// stream, with nt threads per block and ceil(N / (nt * vt)) blocks.
// N == 0 is a no-op. N must not exceed INT32_MAX, because the kernel takes
// `int N`.
template <int nt, int vt, typename func_t>
static void launch_legacy_kernel(int64_t N, const func_t& f) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
  if (N == 0) {
    return;
  }
  const dim3 threads_per_block(nt);
  const int64_t items_per_block = static_cast<int64_t>(nt) * vt;
  const dim3 num_blocks((N + items_per_block - 1) / items_per_block);
  auto cuda_stream = at::cuda::getCurrentCUDAStream();
  elementwise_kernel<nt, vt, func_t>
      <<<num_blocks, threads_per_block, 0, cuda_stream>>>(N, f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
195 
// Calls `f` with one argument per input operand, without dtype conversion.
// Argument k is read with c10::load as the functor's k-th parameter type from
// data[k] + i * strides[k]. `data` is char*, so that offset is counted in
// bytes.
template <typename traits, typename func_t, typename index_t, size_t... Is>
C10_HOST_DEVICE typename traits::result_type invoke_impl(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    int i,
    std::index_sequence<Is...>) {
  // For a nullary functor the pack expansion is empty and `i` / `strides` go
  // unused; the casts silence unused-parameter warnings.
  (void)i;
  (void)strides;
  return f(c10::load<typename traits::template arg<Is>::type>(
      data[Is] + i * strides[Is])...);
}
208 
// Non-casting entry point: generates the argument indices
// 0 .. traits::arity - 1 and forwards to invoke_impl.
template <
    typename func_t,
    typename index_t,
    typename traits = function_traits<func_t>>
C10_HOST_DEVICE typename traits::result_type invoke(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    int i) {
  return invoke_impl<traits>(
      f, data, strides, i, std::make_index_sequence<traits::arity>{});
}
221 
// Casting variant: calls `f` with one argument per input operand.
// Argument k is produced by c10::fetch_and_cast from the runtime dtype
// dtypes[k] to the functor's k-th parameter type. It is read at
// data[k] + i * strides[k], a byte offset since `data` is char*.
template <typename traits, typename func_t, typename index_t, size_t... Is>
C10_HOST_DEVICE typename traits::result_type invoke_impl(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    const ScalarType dtypes[],
    int i,
    std::index_sequence<Is...>) {
  // Unused when the functor takes no arguments (empty pack expansion).
  (void)i;
  (void)strides;
  return f(c10::fetch_and_cast<typename traits::template arg<Is>::type>(
      dtypes[Is], data[Is] + i * strides[Is])...);
}
235 
// Casting entry point: generates the argument indices
// 0 .. traits::arity - 1 and forwards to the dtype-aware invoke_impl.
template <
    typename func_t,
    typename index_t,
    typename traits = function_traits<func_t>>
C10_HOST_DEVICE typename traits::result_type invoke(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    const ScalarType dtypes[],
    int i) {
  return invoke_impl<traits>(
      f, data, strides, dtypes, i, std::make_index_sequence<traits::arity>{});
}
249 
// Runs `f` elementwise over `iter` with no dynamic dtype casting.
//
// Preconditions (all asserted):
// - the iterator is 32-bit indexable;
// - it has exactly one output and traits::arity inputs;
// - needs_dynamic_casting reports no casting.
//
// A contiguous iterator goes through launch_vectorized_kernel. Otherwise each
// linear index is mapped to per-operand byte offsets by an offset calculator
// and processed by the legacy kernel.
template <typename func_t>
void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
  using traits = function_traits<func_t>;
  using arg0_t = typename traits::result_type;
  constexpr int ntensors = traits::arity + 1;

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
  TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));

  // Operand 0 is the output; operands 1..arity are the inputs.
  at::detail::Array<char*, ntensors> data;
  for (int arg = 0; arg < ntensors; arg++) {
    data[arg] = (char*)iter.data_ptr(arg);
  }
  const int64_t numel = iter.numel();

  if (iter.is_contiguous()) {
    launch_vectorized_kernel(numel, f, data);
    return;
  }

  // Results narrower than 4 bytes get more indices per thread.
  constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
  auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
  launch_legacy_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) {
    auto offsets = offset_calc.get(idx);
    arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
    // Passing i = 1 makes invoke() read input k at data[k] + offsets[k].
    *out = invoke(f, &data.data[1], &offsets.data[1], 1);
  });
}
281 
// Runs `f` elementwise over `iter`, converting between each operand's runtime
// dtype and the functor's static argument/result types.
//
// If needs_dynamic_casting<func_t>::check(iter) is false, this delegates to
// gpu_kernel_impl_nocast. Otherwise inputs are read through
// c10::fetch_and_cast (or memory::LoadWithCast). The result is written through
// c10::cast_and_store (or memory::StoreWithCast).
//
// Preconditions (asserted): 32-bit indexable, one output, traits::arity inputs.
template <typename func_t>
void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
  if (!needs_dynamic_casting<func_t>::check(iter)) {
    return gpu_kernel_impl_nocast(iter, f);
  }
  using traits = function_traits<func_t>;
  using arg0_t = typename traits::result_type;
  constexpr int ntensors = traits::arity + 1;

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);

  // Operand 0 is the output; operands 1..arity are the inputs.
  at::detail::Array<char*, ntensors> data;
  for (int arg = 0; arg < ntensors; arg++) {
    data[arg] = (char*)iter.data_ptr(arg);
  }
  const int64_t numel = iter.numel();

  if (!iter.is_contiguous()) {
    // Strided operands: map each linear index to per-operand byte offsets.
    // Passing i = 1 to invoke() makes it read input k at data[k] + offsets[k].
    at::detail::Array<ScalarType, ntensors> dtypes;
    for (int arg = 0; arg < ntensors; arg++) {
      dtypes[arg] = iter.dtype(arg);
    }
    auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
    launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) {
      auto offsets = offset_calc.get(idx);
      void* out = data[0] + offsets[0];
      arg0_t result =
          invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
      c10::cast_and_store<arg0_t>(dtypes[0], out, result);
    });
    return;
  }

#ifdef USE_ROCM
  // Contiguous on ROCm: one index per thread through the legacy kernel.
  // Operand k is addressed as data[k] + inner_stride[k] * idx.
  at::detail::Array<ScalarType, ntensors> dtypes;
  at::detail::Array<int, ntensors> strides;
  auto inner_strides = iter.get_inner_strides();
  for (int arg = 0; arg < ntensors; arg++) {
    dtypes[arg] = iter.dtype(arg);
    strides[arg] = inner_strides[arg];
  }
  launch_legacy_kernel<512, 1>(numel, [=] GPU_LAMBDA(int idx) {
    void* out = data[0] + strides[0] * idx;
    arg0_t result =
        invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
    c10::cast_and_store<arg0_t>(dtypes[0], out, result);
  });
#else
  // Contiguous: unrolled kernel with trivial offsets and a casting
  // loader/storer.
  auto cast_loader = memory::LoadWithCast<traits::arity>(iter);
  auto cast_storer = memory::StoreWithCast<1>(iter);
  launch_unrolled_kernel(
      numel,
      f,
      data,
      TrivialOffsetCalculator<traits::arity>(),
      TrivialOffsetCalculator<1>(),
      cast_loader,
      cast_storer);
#endif
}
346 
347 } // namespace native
348 } // namespace at
349