// aten/src/ATen/cuda/jiterator.cu
#include <ATen/jit_macros.h>

#if AT_USE_JITERATOR()

#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/jiterator.h>
#include <ATen/cuda/jiterator_impl.h>

#include <iostream>
#include <utility>
#include <chrono>
namespace at {
namespace native {

static inline void launch_jitted_vectorized_kernel_dynamic(
  const std::string& name, TensorIteratorBase& iter,
  DeviceIndex dev_idx, int64_t N, const std::string& f, void* data_ptr,
  const c10::SmallVector<at::Scalar>& extra_args, bool return_by_ref) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // N is still int64_t for the computation, but it's always safe to cast result to int
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
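  // e.g. with the usual launch configuration of 128 threads x 4 elements per thread
  // (block_work_size() == 512; assumed here, not guaranteed), N == 1000 gives grid == 2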

  const int vec_size = jitted_can_vectorize_up_to(iter);
  bool vectorized = vec_size > 1;

  // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
  //   fn_ptr is set to the appropriate function based on the vec size and GPU used
  // TODO: Memory use can probably be optimized by re-using kernels across GPUs with
  //   the same compute capability

  int nInputs = iter.ninputs();
  int nOutputs = iter.noutputs();
  const at::ScalarType common_dtype = iter.common_dtype();
  std::string f_inputs_type_str = at::cuda::jit::typeName(common_dtype);
  std::string compute_type_str = at::cuda::jit::typeName(toOpMathType(common_dtype));
  std::string result_type_str = at::cuda::jit::typeName(common_dtype);
  c10::SmallVector<std::string> extra_args_types = get_extra_args_typenames(extra_args);

  // The cache key includes all the parameters to generate_code + vec_size + dev_idx
  std::stringstream ss;
  ss << nInputs << "_" << nOutputs << f;
  ss << f_inputs_type_str << compute_type_str << result_type_str;
  ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);
  ss << extra_args_types;
  ss << vec_size;
  // DeviceIndex, e.g. int8_t, is not treated as a number by the stream, cast to int as a workaround
  ss << static_cast<int>(dev_idx);
  const std::string cache_key = ss.str();
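  // e.g. a two-input, one-output float kernel vectorized by 4 on device 0 always produces the
  // same key, so that combination is JIT-compiled at most once per process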

  static std::mutex _jiterator_mutex;
  static std::unordered_map<std::string, at::cuda::jit::NvrtcFunction> fns;
  at::cuda::jit::NvrtcFunction* fn_ptr = &fns[cache_key];

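  // The cache is checked once without the lock for the common hit path, then re-checked under
  // the lock so that only one thread compiles a missing kernel.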
  if (!fn_ptr->function) {
    const std::lock_guard<std::mutex> lock{_jiterator_mutex};
    if (!fn_ptr->function) { // cache miss!
      // Generates program
      auto code = at::cuda::jit::generate_code(nInputs, nOutputs, f, name,
                                               f_inputs_type_str, compute_type_str, result_type_str,
                                               /*contiguous=*/true, /*dynamic_casting=*/false,
                                               at::cuda::jit::BinaryFuncVariant::NoScalar,
                                               extra_args_types,
                                               vectorized, vec_size,
                                               return_by_ref);
      std::string kernel_name = vectorized ? name + "_vectorized" + std::to_string(vec_size) : name;
      // Acquires the program
      *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
    }
  }

  // size of `extra_args` is unknown at compile-time
  auto extra_args_size = extra_args.size();

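  // scalar_val is only a placeholder here: this dynamic path always requests
  // BinaryFuncVariant::NoScalar kernels, so the value is never read; it just fills the scalar
  // slot in the packed kernel arguments below.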
  float scalar_val = 0;

  if (vectorized) {
    // pack args for kernel launch
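    // 3 fixed slots: numel, the packed tensor data pointers, and the scalar placeholder.
    // No offset calculators are needed because this path is only taken for contiguous tensors.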
    constexpr int kernel_args = 3;
    auto args = std::make_unique<void*[]>(kernel_args + extra_args_size);
    args[0] = static_cast<void*>(&N);
    args[1] = data_ptr;
    args[2] = static_cast<void*>(&scalar_val);

    for (const auto i : c10::irange(extra_args_size)) {
      // since 3 slots are already filled in `args`
      args[i + 3] = const_cast<void*>(extra_args[i].data_ptr());
    }
    at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args.get(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
  } else {
    TrivialOffsetCalculatorVariant input_offset_calculator(iter.ninputs());
    void* ic_ptr = input_offset_calculator.data_ptr();
    TrivialOffsetCalculatorVariant output_offset_calculator(iter.noutputs());
    void* oc_ptr = output_offset_calculator.data_ptr();

    auto l = memory::LoadWithoutCast();
    auto s = memory::StoreWithoutCast();

    // pack args for kernel launch
    constexpr int kernel_args = 7;
    auto args = std::make_unique<void*[]>(kernel_args + extra_args_size);
    args[0] = static_cast<void*>(&N);
    args[1] = data_ptr;
    args[2] = ic_ptr;
    args[3] = oc_ptr;
    args[4] = static_cast<void*>(&l);
    args[5] = static_cast<void*>(&s);
    args[6] = static_cast<void*>(&scalar_val);

    for (const auto i : c10::irange(extra_args_size)) {
      // since 7 slots are already filled in `args`
      args[i + 7] = const_cast<void*>(extra_args[i].data_ptr());
    }

    at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args.get(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
  }
}

static inline void launch_jitted_unrolled_kernel_dynamic(
  const std::string& name, TensorIteratorBase& iter,
  DeviceIndex dev_idx, int64_t N, const std::string& f, void* data_ptr,
  void* ic_ptr, void* oc_ptr, void* l_ptr, void* s_ptr, bool contiguous, bool dynamic_casting,
  const c10::SmallVector<at::Scalar>& extra_args, bool return_by_ref) {

  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // casting result to int is always safe, intermediate is int64 and won't overflow
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();

  int nInputs = iter.ninputs();
  int nOutputs = iter.noutputs();
  const at::ScalarType common_dtype = iter.common_dtype();
  std::string f_inputs_type_str = at::cuda::jit::typeName(common_dtype);
  std::string compute_type_str = at::cuda::jit::typeName(toOpMathType(common_dtype));
  std::string result_type_str = at::cuda::jit::typeName(common_dtype);
  c10::SmallVector<std::string> extra_args_types = get_extra_args_typenames(extra_args);

  // The cache key includes all the parameters to generate_code + dev_idx
  std::stringstream ss;
  ss << nInputs << "_" << nOutputs << f;
  ss << f_inputs_type_str << compute_type_str << result_type_str;
  ss << contiguous << dynamic_casting;
  ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);
  ss << extra_args_types;
  ss << static_cast<int>(dev_idx); // DeviceIndex is int8_t; stream as a number (matches the vectorized variant above)
  const std::string cache_key = ss.str();

  static std::mutex _jiterator_mutex;
  static std::unordered_map<std::string, at::cuda::jit::NvrtcFunction> fns;

  at::cuda::jit::NvrtcFunction* fn_ptr = &fns[cache_key];
  if (!fn_ptr->function) {
    const std::lock_guard<std::mutex> lock{_jiterator_mutex};
    if (!fn_ptr->function) {
      auto code = at::cuda::jit::generate_code(nInputs, nOutputs, f, name,
                                               f_inputs_type_str, compute_type_str, result_type_str,
                                               contiguous, dynamic_casting,
                                               at::cuda::jit::BinaryFuncVariant::NoScalar,
                                               extra_args_types, /*vectorized*/false, /*vec_size*/0, return_by_ref);
      *fn_ptr = at::cuda::jit::jit_pwise_function(code, name);
    }
  }

  float scalar_val = 0;

  // pack args for kernel launch
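  // 7 fixed slots matching the generated unrolled kernel: numel, packed tensor data pointers,
  // input/output offset calculators, loader, storer, and the scalar placeholder.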
  constexpr int kernel_args = 7;
  auto extra_args_size = extra_args.size();
  auto args = std::make_unique<void*[]>(kernel_args + extra_args_size);
  args[0] = static_cast<void*>(&N);
  args[1] = data_ptr;
  args[2] = ic_ptr;
  args[3] = oc_ptr;
  args[4] = l_ptr;
  args[5] = s_ptr;
  args[6] = static_cast<void*>(&scalar_val);

  for (const auto i : c10::irange(extra_args_size)) {
    // since 7 slots are already filled in `args`
    args[i + 7] = const_cast<void*>(extra_args[i].data_ptr());
  }

  at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args.get(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
}

void jitted_gpu_kernel_dynamic_impl(
    const std::string& kernel_name,
    TensorIteratorBase& iter,
    const std::string& f,
    const bool dynamic_casting,
    const c10::SmallVector<at::Scalar>& extra_args,
    bool return_by_ref) {

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.noutputs() <= 8);
  TORCH_INTERNAL_ASSERT(iter.ninputs() <= 8);

  ArrayVariant data(iter);
  void* data_ptr = data.data_ptr();

  int64_t numel = iter.numel();
  bool contiguous = iter.is_contiguous();

  // Decides which of 4 kernel types to launch
  // Variations are:
  //   - Case 1: no dynamic casting and contiguous
  //   - Case 2: no dynamic casting and noncontiguous
  //   - Case 3: dynamic casting and contiguous
  //   - Case 4: dynamic casting and noncontiguous
  // These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl

  if (!dynamic_casting) {
    if (contiguous) {
      // Case 1: no dynamic casting and contiguous
      launch_jitted_vectorized_kernel_dynamic(kernel_name, iter,
         iter.device().index(), numel, f, data_ptr, extra_args, return_by_ref);
      return;
    }

    // Case 2: no dynamic casting and noncontiguous
    OffsetCalculatorVariant</*is_input=*/true> input_offset_calculator(iter);
    void* ic_ptr = input_offset_calculator.data_ptr();
    OffsetCalculatorVariant</*is_input=*/false> output_offset_calculator(iter);
    void* oc_ptr = output_offset_calculator.data_ptr();

    auto loader = memory::LoadWithoutCast();
    auto storer = memory::StoreWithoutCast();
    void* l_ptr = static_cast<void*>(&loader);
    void* s_ptr = static_cast<void*>(&storer);

    launch_jitted_unrolled_kernel_dynamic(
      kernel_name, iter, iter.device().index(), numel, f, data_ptr,
      ic_ptr, oc_ptr, l_ptr, s_ptr, contiguous, dynamic_casting, extra_args, return_by_ref);

    return;
  }

  // Cases 3 and 4 are handled below
  // Both require construction of one or more storers and loaders

  // Creates load casts from the inputs (note the offset indexing into the iterator's noutputs...ntensors-1 tensors)
  LoadWithCastVariant loader(iter);
  void* l_ptr = loader.data_ptr();

  // Creates store casts to the outputs (tensors 0...noutputs-1 in the TensorIterator)
  StoreWithCastVariant storer(iter);
  void* s_ptr = storer.data_ptr();

  if (contiguous) {
    // Case 3: dynamic casting and contiguous
    TrivialOffsetCalculatorVariant input_offset_calculator(iter.ninputs());
    void* ic_ptr = input_offset_calculator.data_ptr();
    TrivialOffsetCalculatorVariant output_offset_calculator(iter.noutputs());
    void* oc_ptr = output_offset_calculator.data_ptr();

    launch_jitted_unrolled_kernel_dynamic(
      kernel_name, iter, iter.device().index(), numel, f, data_ptr,
      ic_ptr, oc_ptr, l_ptr, s_ptr, contiguous, dynamic_casting, extra_args, return_by_ref);
    return;
  }

  // Case 4: dynamic casting and noncontiguous
  OffsetCalculatorVariant</*is_input=*/true> input_offset_calculator(iter);
  void* ic_ptr = input_offset_calculator.data_ptr();
  OffsetCalculatorVariant</*is_input=*/false> output_offset_calculator(iter);
  void* oc_ptr = output_offset_calculator.data_ptr();

  launch_jitted_unrolled_kernel_dynamic(
      kernel_name, iter, iter.device().index(), numel, f, data_ptr,
      ic_ptr, oc_ptr, l_ptr, s_ptr, contiguous, dynamic_casting, extra_args, return_by_ref);
}

// Entrypoint for the dynamic version of jitted GPU kernels, which accepts a dynamic number of inputs
// and arbitrary types of inputs and extra args. This dynamic version is needed for the jiterator's
// Python interface, since the kernel definition is not known at compile time.
// Similarly, launch_jitted_vectorized_kernel_dynamic and launch_jitted_unrolled_kernel_dynamic exist
// to handle arbitrary functions defined in Python user code.
// For the templated version, see note [Jiterator] in JitLoops.cuh for more details.
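//
// For reference, the Python-side prototype that reaches this entrypoint looks roughly like the
// sketch below (based on the experimental torch.cuda.jiterator interface; treat it as illustrative,
// not a stable API):
//   code_string = "template <typename T> T my_op(T x, T y, T alpha) { return -x + alpha * y; }"
//   jitted_fn = torch.cuda.jiterator._create_jit_fn(code_string, alpha=1.0)
//   out = jitted_fn(a, b)  # a, b are CUDA tensors; alpha arrives here via `extra_args`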
void jitted_gpu_kernel_dynamic(
    const std::string& kernel_name,
    TensorIteratorBase& iter,
    const std::string& f,
    const c10::SmallVector<at::Scalar>& extra_args,
    bool return_by_ref) {

  // TODO: much of the preamble is common to both jitted_gpu_kernel and gpu_kernel
  //   Maybe it could be refactored?
  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(
      iter.device(arg).is_cuda(),
      "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
  }

  if (iter.numel() == 0) {
    return;
  }

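  // The generated kernels use 32-bit indexing, so an iterator that needs 64-bit indexing is
  // split into 32-bit-indexable sub-iterators, each handled by a recursive call.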
  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      jitted_gpu_kernel_dynamic(kernel_name, sub_iter, f, extra_args, return_by_ref);
    }
    return;
  }

  // Computes if dynamic casting is needed
  // Dynamic casting is needed if an input's or output's dtype differs from the common dtype
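  // (e.g. combining a half input with a float input promotes the common dtype to float, so the
  // half tensor must be cast on load)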
  bool needs_dynamic_casting = false;
  const at::ScalarType common_dtype = iter.common_dtype();
  for (auto i = 0; i < iter.ntensors(); ++i) {
    if (iter.dtype(i) != common_dtype) {
      needs_dynamic_casting = true;
      break;
    }
  }

  jitted_gpu_kernel_dynamic_impl(kernel_name, iter, f, needs_dynamic_casting, extra_args, return_by_ref);
}

} // namespace native

namespace cuda {

c10::SmallVector<at::Tensor> CompileAndLaunchKernel(
  const std::string& code_string,
  const std::string& kernel_name,
  const int num_outputs,
  const c10::SmallVector<at::Tensor>& tensors,
  const c10::SmallVector<at::Scalar>& extra_args,
  bool return_by_ref) {

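  // Builds a TensorIterator with num_outputs freshly allocated outputs and the given tensors as
  // const inputs; the config below handles dtype promotion, broadcasting, and output allocation.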
  c10::SmallVector<at::Tensor> outs(num_outputs);
  TensorIteratorConfig config;
  config
    .set_check_mem_overlap(true)
    .allow_cpu_scalars(false)
    .promote_inputs_to_common_dtype(true)
    .cast_common_dtype_to_outputs(true)
    .enforce_safe_casting_to_output(true)
    .check_all_same_device(true);
  for (int i = 0; i < num_outputs; ++i) {
    config.add_owned_output(outs[i]);
  }
  for (const auto& t: tensors) {
    config.add_const_input(t);
  }
  TensorIterator iter = config.build();

  CUDAGuard guard(iter.device());
  at::native::jitted_gpu_kernel_dynamic(kernel_name, iter, code_string, extra_args, return_by_ref);

  c10::SmallVector<at::Tensor> outputs;
  for (int i = 0; i < num_outputs; ++i) {
    outputs.emplace_back(iter.output(i));
  }

  return outputs;
}

}} // namespace at::cuda

#endif // AT_USE_JITERATOR()