#pragma once

#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/OpMathType.h>
#include <ATen/native/cuda/thread_constants.h>

#include <thrust/tuple.h>

#include <ATen/native/cuda/MemoryAccess.cuh>

namespace at { namespace native {

template<int N>
static OffsetCalculator<N> make_input_offset_calculator(const TensorIteratorBase& iter) {
  // The array size cannot be 0; guard against N == 0.
  constexpr int array_size = std::max<int>(N, 1);
  TORCH_INTERNAL_ASSERT(N == iter.ntensors() - iter.noutputs());
  std::array<const int64_t*, array_size> strides;
  int64_t element_sizes[array_size];
  // Inputs follow the outputs in the iterator's operand list, hence the
  // noutputs() offset.
  for (int i = 0; i < N; i++) {
    strides[i] = iter.strides(i + iter.noutputs()).data();
    element_sizes[i] = iter.element_size(i + iter.noutputs());
  }
  return OffsetCalculator<N>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
}

template <int num_outputs = 1>
static OffsetCalculator<num_outputs> make_output_offset_calculator(const TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(num_outputs == iter.noutputs());
  std::array<const int64_t*, num_outputs> strides;
  int64_t element_sizes[num_outputs];
  for (int i = 0; i < num_outputs; i++) {
    strides[i] = iter.strides(i).data();
    element_sizes[i] = iter.element_size(i);
  }
  return OffsetCalculator<num_outputs>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
}
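
// Usage note (illustrative, not part of this header's API): the calculators
// built above convert a linear element index into per-operand offsets, which
// the non-contiguous code paths below rely on. A minimal sketch, assuming a
// binary op whose inputs sit at operand indices 1 and 2:
//
//   auto input_calc = make_input_offset_calculator<2>(iter);
//   // on the device, for a given linear index:
//   //   auto offsets = input_calc.get(linear_idx);
//   //   // offsets[0] / offsets[1] index into the first / second input.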

template<typename func_t, typename policy_t>
__device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
  using traits = function_traits<func_t>;
  using return_t = typename traits::result_type;
  using args_t = typename traits::ArgsTuple;

  int idx = blockIdx.x;

  return_t results[thread_work_size()];
  args_t args[thread_work_size()];

  // load
  policy.load(args, idx);

  // compute
  #pragma unroll
  for (int i = 0; i < thread_work_size(); i++) {
    if (policy.check_inbounds(i)) {
      results[i] = c10::guts::apply(f, args[i]);
    }
  }

  // store
  policy.store(results, idx);
}
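
// The `policy` argument above supplies the memory-access strategy (the real
// policies live in MemoryAccess.cuh, e.g. memory::policies::unroll). A rough
// sketch of the interface elementwise_kernel_helper expects from it, with
// hypothetical names for the sketch only:
//
//   struct example_policy {
//     // whether the i-th work item of this thread is within bounds
//     __device__ bool check_inbounds(int thread_work_elem);
//     // fill `args` with the inputs this thread should process in block `idx`
//     template <typename args_t>
//     __device__ void load(args_t* args, int idx);
//     // write this thread's results back for block `idx`
//     template <typename return_t>
//     __device__ void store(return_t* results, int idx);
//   };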

}}  // namespace at::native

#include <ATen/native/cuda/CUDALoops.cuh>

namespace at::native {

template <typename func_t>
void gpu_kernel_nocast(TensorIteratorBase& iter, const func_t& f) {
  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(
      iter.device(arg).is_cuda(),
      "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
  }

  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      gpu_kernel_nocast(sub_iter, f);
    }
    return;
  }

  gpu_kernel_impl_nocast(iter, f);
}

template <typename func_t>
void gpu_kernel(TensorIteratorBase& iter, const func_t& f) {
  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(
      iter.device(arg).is_cuda(),
      "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
  }

  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      gpu_kernel(sub_iter, f);
    }
    return;
  }

  gpu_kernel_impl(iter, f);
}
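
// Usage sketch for gpu_kernel (illustrative; `iter`, the dtypes and the lambda
// are assumptions for the example, not part of this header):
//
//   at::TensorIterator iter = at::TensorIterator::binary_op(out, a, b);
//   gpu_kernel(iter, [] __host__ __device__ (float x, float y) -> float {
//     return x + y;
//   });
//
// The functor's argument and result types describe how each element is
// computed; if they do not match the iterator's dtypes, gpu_kernel falls back
// to a dynamically casting implementation (see gpu_kernel_impl in
// CUDALoops.cuh), whereas gpu_kernel_nocast skips that check.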

template<typename arg1_t, typename arg2_t, typename return_t, typename func_t>
struct AUnaryFunctor {
  using traits = function_traits<func_t>;
  using opmath_arg1_t = typename traits::template arg<0>::type;
  __device__ return_t operator()(arg2_t b) const {
    return f(a, b);
  }
  // NB: scalar is stored in higher precision!
  AUnaryFunctor(func_t f_, opmath_arg1_t a_): f(f_), a(a_) {}
  private:
    func_t f;
    opmath_arg1_t a;
};

template<typename arg1_t, typename arg2_t, typename return_t, typename func_t>
struct BUnaryFunctor {
  using traits = function_traits<func_t>;
  using opmath_arg2_t = typename traits::template arg<1>::type;
  __device__ return_t operator()(arg1_t a) const {
    return f(a, b);
  }
  // NB: scalar is stored in higher precision!
  BUnaryFunctor(func_t f_, opmath_arg2_t b_): f(f_), b(b_) {}
  private:
    func_t f;
    opmath_arg2_t b;
};

// Though seemingly a no-op, this functor inserts casts from arg1_t/arg2_t to
// func_t's argument types (which may be higher precision), as well as a cast
// of the result to return_t.
template <typename arg1_t, typename arg2_t, typename return_t, typename func_t>
struct BinaryFunctor {
  __device__ return_t operator()(arg1_t a, arg2_t b) const {
    return f(a, b);
  }
  BinaryFunctor(func_t f_): f(f_) {}
  private:
    func_t f;
};
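
// Illustration of the casts mentioned above (hypothetical types): with
// func_t = float(*)(float, float) and
//
//   BinaryFunctor<at::Half, at::Half, at::Half, float(*)(float, float)> fn(f);
//
// a call fn(x, y) widens the Half arguments to float, evaluates f in float,
// and narrows the float result back to Half on return.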

// Unlike gpu_kernel_with_scalars, this allows you to pass a func_t that
// accepts its inputs at higher precision (typically opmath_t), while still
// ensuring that we load from memory at the original precision (scalar_t)
// to avoid expensive loads. For the whole sordid story see
// https://dev-discuss.pytorch.org/t/cuda-loops-case-study-code-generation-vs-templates/302
template <typename arg1_t, typename arg2_t = arg1_t, typename return_t = arg1_t, typename func_t>
void opmath_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);

  using traits = function_traits<func_t>;
  using opmath_arg1_t = typename traits::template arg<0>::type;
  using opmath_arg2_t = typename traits::template arg<1>::type;
  static_assert(
      traits::arity == 2,
      "gpu_kernel_with_scalars only supports two input arguments");

  if (iter.is_cpu_scalar(1)) {
    AUnaryFunctor<arg1_t, arg2_t, return_t, func_t> af(f, iter.scalar_value<opmath_arg1_t>(1));
    iter.remove_operand(1);
    // TODO: When all kernels that use gpu_kernel_with_scalars are ported to
    // structured, this device guard can be deleted. It works around incorrect
    // device guard generation for pre-structured kernels; structured kernels
    // do it right, so for them we can assume the device is already set
    // correctly.
    const OptionalDeviceGuard device_guard(iter.device(1));
    gpu_kernel(iter, af);
  } else if (iter.is_cpu_scalar(2)) {
    BUnaryFunctor<arg1_t, arg2_t, return_t, func_t> bf(f, iter.scalar_value<opmath_arg2_t>(2));
    iter.remove_operand(2);
    gpu_kernel(iter, bf);
  } else {
    gpu_kernel(iter, BinaryFunctor<arg1_t, arg2_t, return_t, func_t>(f));
  }
}
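
// Usage sketch for opmath_gpu_kernel_with_scalars (illustrative; `iter` and
// the type aliases are assumptions for the example):
//
//   using scalar_t = at::Half;
//   using opmath_t = at::opmath_type<scalar_t>;  // float
//   opmath_gpu_kernel_with_scalars<scalar_t>(iter,
//       [] __host__ __device__ (opmath_t a, opmath_t b) -> opmath_t {
//         return a * b;
//       });
//
// If either input is a CPU scalar it is captured once at opmath precision and
// the remaining tensor is handled by a unary functor; otherwise BinaryFunctor
// loads both operands as scalar_t and widens them per element.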

template <typename scalar_t, typename return_t = scalar_t, typename func_t>
void opmath_symmetric_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
  // Use the symmetric property of the functor to reduce the number of
  // generated kernels; requires f(a, b) == f(b, a).
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);

  using traits = function_traits<func_t>;
  using opmath_arg_t = typename traits::template arg<0>::type;
  static_assert(
      traits::arity == 2,
      "gpu_kernel_with_scalars only supports two input arguments");
  static_assert(std::is_same<opmath_arg_t, typename traits::template arg<1>::type>::value,
                "f is not symmetric");

  OptionalDeviceGuard device_guard;
  opmath_arg_t scalar_val{};

  if (iter.is_cpu_scalar(1)) {
    scalar_val = iter.scalar_value<opmath_arg_t>(1);
    iter.remove_operand(1);

    // TODO: When all kernels that use gpu_kernel_with_scalars are ported to
    // structured, this device guard can be deleted. It works around incorrect
    // device guard generation for pre-structured kernels; structured kernels
    // do it right, so for them we can assume the device is already set
    // correctly.
    device_guard.reset_device(iter.device(1));
  } else if (iter.is_cpu_scalar(2)) {
    scalar_val = iter.scalar_value<opmath_arg_t>(2);
    iter.remove_operand(2);
  }

  if (iter.ninputs() == 2) {
    gpu_kernel(iter, BinaryFunctor<scalar_t, scalar_t, return_t, func_t>(f));
  } else {
    AUnaryFunctor<scalar_t, scalar_t, return_t, func_t> unary_f(f, scalar_val);
    gpu_kernel(iter, unary_f);
  }
}
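
// Usage sketch (illustrative, reusing the hypothetical scalar_t/opmath_t
// aliases from the sketch above): for a commutative op such as multiplication,
// the symmetric variant avoids a separate BUnaryFunctor instantiation, since a
// scalar in either position can be handled by AUnaryFunctor:
//
//   opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(iter,
//       [] __host__ __device__ (opmath_t a, opmath_t b) -> opmath_t {
//         return a * b;
//       });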

// Legacy variant that assumes that func_t has the correct types
// that we expect to load from memory
template <typename func_t>
void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
  using traits = function_traits<func_t>;
  static_assert(
      traits::arity == 2,
      "gpu_kernel_with_scalars only supports two input arguments");
  using arg1_t = typename traits::template arg<0>::type;
  using arg2_t = typename traits::template arg<1>::type;
  using return_t = typename traits::result_type;
  opmath_gpu_kernel_with_scalars<arg1_t, arg2_t, return_t, func_t>(iter, f);
}

namespace { // functions for `gpu_kernel_multiple_outputs`.

// Check that the return type is `thrust::tuple`, not `std::tuple`.
template <typename T> struct is_tuple: std::false_type {};

template <typename ...T> struct is_tuple<thrust::tuple<T...>>: std::true_type {};

template <int num_outputs, typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void unrolled_elementwise_kernel_for_multi_outputs(int N, func_t f, array_t data, inp_calc_t ic, out_calc_t oc) {
  int remaining = N - block_work_size() * blockIdx.x;
  elementwise_kernel_helper(f, memory::policies::multi_outputs_unroll<array_t, inp_calc_t, out_calc_t, num_outputs>(data, remaining, ic, oc));
}

template <int num_outputs, typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t>
static inline void launch_unrolled_kernel_for_multi_outputs(int64_t N, const func_t& f, array_t data, inp_calc_t ic, out_calc_t oc) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // Each block handles block_work_size() elements, so round the grid size up.
  int64_t grid = (N + block_work_size() - 1) / block_work_size();
  auto stream = at::cuda::getCurrentCUDAStream();
  unrolled_elementwise_kernel_for_multi_outputs<num_outputs, func_t, array_t><<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename func_t>
void gpu_kernel_multiple_outputs_impl(TensorIteratorBase& iter, const func_t& f) {
  using traits = function_traits<func_t>;
  using output_t = typename traits::result_type;
  static_assert(is_tuple<output_t>::value, "f's return type must be `thrust::tuple`");
  constexpr int num_outputs = thrust::tuple_size<output_t>::value;
  constexpr int num_inputs = traits::arity;
  constexpr int ntensors = num_outputs + num_inputs;

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ntensors() == ntensors);

  at::detail::Array<char*, ntensors> data;
  for (int i = 0; i < ntensors; i++) {
    data[i] = (char*)iter.data_ptr(i);
  }

  int64_t numel = iter.numel();

  if (iter.is_contiguous()) {
    auto input_calc = TrivialOffsetCalculator<num_inputs>();
    auto output_calc = TrivialOffsetCalculator<num_outputs>();
    launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
  } else {
    auto input_calc = make_input_offset_calculator<num_inputs>(iter);
    auto output_calc = make_output_offset_calculator<num_outputs>(iter);
    launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
  }
}
} // namespace

template <typename func_t>
void gpu_kernel_multiple_outputs(TensorIteratorBase& iter, const func_t& f) {
  ASSERT_HOST_DEVICE_LAMBDA(func_t);

  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
  }

  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      gpu_kernel_multiple_outputs(sub_iter, f);
    }
    return;
  }

  gpu_kernel_multiple_outputs_impl(iter, f);
}
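
// Usage sketch for gpu_kernel_multiple_outputs (illustrative; the iterator is
// assumed to be configured with two outputs followed by two inputs):
//
//   gpu_kernel_multiple_outputs(iter,
//       [] __host__ __device__ (float a, float b) -> thrust::tuple<float, float> {
//         return thrust::make_tuple(a + b, a - b);
//       });
//
// Each element of the returned thrust::tuple is written to the corresponding
// output tensor, in order.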

} // namespace at::native