#include <ATen/jit_macros.h>

#if AT_USE_JITERATOR()

#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/jiterator.h>
#include <ATen/cuda/jiterator_impl.h>

#include <iostream>
#include <utility>
#include <chrono>
namespace at {
namespace native {

static inline void launch_jitted_vectorized_kernel_dynamic(
    const std::string& name, TensorIteratorBase& iter,
    DeviceIndex dev_idx, int64_t N, const std::string& f, void* data_ptr,
    const c10::SmallVector<at::Scalar>& extra_args, bool return_by_ref) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // N is still int64_t for the computation, but it's always safe to cast result to int
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();

  const int vec_size = jitted_can_vectorize_up_to(iter);
  bool vectorized = vec_size > 1;
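  // A vec_size of 1 means the data pointers are not aligned well enough for vectorized
  // loads/stores, so the non-vectorized (but still contiguous) path below is taken.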

  // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
  // fn_ptr is set to the appropriate function based on the vec size and GPU used
  // TODO: Memory use can probably be optimized by re-using kernels across GPUs with
  //   the same compute capability

  int nInputs = iter.ninputs();
  int nOutputs = iter.noutputs();
  const at::ScalarType common_dtype = iter.common_dtype();
  std::string f_inputs_type_str = at::cuda::jit::typeName(common_dtype);
  std::string compute_type_str = at::cuda::jit::typeName(toOpMathType(common_dtype));
  std::string result_type_str = at::cuda::jit::typeName(common_dtype);
  c10::SmallVector<std::string> extra_args_types = get_extra_args_typenames(extra_args);

  // The cache key includes all the parameters to generate_code + vec_size + dev_idx
  std::stringstream ss;
  ss << nInputs << "_" << nOutputs << f;
  ss << f_inputs_type_str << compute_type_str << result_type_str;
  ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);
  ss << extra_args_types;
  ss << vec_size;
  // DeviceIndex (an int8_t) would be formatted as a character by the stream; cast to int
  // so it is written as a number
  ss << static_cast<int>(dev_idx);
  const std::string cache_key = ss.str();

  static std::mutex _jiterator_mutex;
  static std::unordered_map<std::string, at::cuda::jit::NvrtcFunction> fns;
  at::cuda::jit::NvrtcFunction* fn_ptr = &fns[cache_key];

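  // Double-checked locking: operator[] default-constructs an empty NvrtcFunction entry on the
  // first lookup for this key, and the kernel is compiled at most once, under the mutex below.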
  if (!fn_ptr->function) {
    const std::lock_guard<std::mutex> lock{_jiterator_mutex};
    if (!fn_ptr->function) { // cache miss!
      // Generates program
      auto code = at::cuda::jit::generate_code(nInputs, nOutputs, f, name,
                    f_inputs_type_str, compute_type_str, result_type_str,
                    /*contiguous=*/true, /*dynamic_casting=*/false,
                    at::cuda::jit::BinaryFuncVariant::NoScalar,
                    extra_args_types,
                    vectorized, vec_size,
                    return_by_ref);
      std::string kernel_name = vectorized ? name + "_vectorized" + std::to_string(vec_size) : name;
      // Acquires the program
      *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
    }
  }

  // size of `extra_args` is unknown at compile-time
  auto extra_args_size = extra_args.size();

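  // scalar_val is an unused placeholder: the generated kernel's signature always has a slot
  // for a scalar argument, but this dynamic path only uses BinaryFuncVariant::NoScalar.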
  float scalar_val = 0;

  if (vectorized) {
    // pack args for kernel launch
    constexpr int kernel_args = 3;
    auto args = std::make_unique<void*[]>(kernel_args + extra_args_size);
    args[0] = static_cast<void*>(&N);
    args[1] = data_ptr;
    args[2] = static_cast<void*>(&scalar_val);

    for (const auto i : c10::irange(extra_args_size)) {
      // since 3 slots are already filled in `args`
      args[i + 3] = const_cast<void*>(extra_args[i].data_ptr());
    }
    at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args.get(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
  } else {
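    // vec_size == 1: launch the non-vectorized kernel. The data is still contiguous and
    // already in the common dtype, so trivial (identity) offset calculators and
    // no-cast load/store suffice.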
    TrivialOffsetCalculatorVariant input_offset_calculator(iter.ninputs());
    void* ic_ptr = input_offset_calculator.data_ptr();
    TrivialOffsetCalculatorVariant output_offset_calculator(iter.noutputs());
    void* oc_ptr = output_offset_calculator.data_ptr();

    auto l = memory::LoadWithoutCast();
    auto s = memory::StoreWithoutCast();

    // pack args for kernel launch
    constexpr int kernel_args = 7;
    auto args = std::make_unique<void*[]>(kernel_args + extra_args_size);
    args[0] = static_cast<void*>(&N);
    args[1] = data_ptr;
    args[2] = ic_ptr;
    args[3] = oc_ptr;
    args[4] = static_cast<void*>(&l);
    args[5] = static_cast<void*>(&s);
    args[6] = static_cast<void*>(&scalar_val);

    for (const auto i : c10::irange(extra_args_size)) {
      // since 7 slots are already filled in `args`
      args[i + 7] = const_cast<void*>(extra_args[i].data_ptr());
    }

    at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args.get(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
  }
}

static inline void launch_jitted_unrolled_kernel_dynamic(
    const std::string& name, TensorIteratorBase& iter,
    DeviceIndex dev_idx, int64_t N, const std::string& f, void* data_ptr,
    void* ic_ptr, void* oc_ptr, void* l_ptr, void* s_ptr, bool contiguous, bool dynamic_casting,
    const c10::SmallVector<at::Scalar>& extra_args, bool return_by_ref) {

  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // casting result to int is always safe, intermediate is int64 and won't overflow
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();

  int nInputs = iter.ninputs();
  int nOutputs = iter.noutputs();
  const at::ScalarType common_dtype = iter.common_dtype();
  std::string f_inputs_type_str = at::cuda::jit::typeName(common_dtype);
  std::string compute_type_str = at::cuda::jit::typeName(toOpMathType(common_dtype));
  std::string result_type_str = at::cuda::jit::typeName(common_dtype);
  c10::SmallVector<std::string> extra_args_types = get_extra_args_typenames(extra_args);

  // The cache key includes all the parameters to generate_code + dev_idx
  std::stringstream ss;
  ss << nInputs << "_" << nOutputs << f;
  ss << f_inputs_type_str << compute_type_str << result_type_str;
  ss << contiguous << dynamic_casting;
  ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);
  ss << extra_args_types;
  // DeviceIndex is an int8_t; cast to int so it is streamed as a number rather than a character
  ss << static_cast<int>(dev_idx);
  const std::string cache_key = ss.str();

  static std::mutex _jiterator_mutex;
  static std::unordered_map<std::string, at::cuda::jit::NvrtcFunction> fns;

  at::cuda::jit::NvrtcFunction* fn_ptr = &fns[cache_key];
  if (!fn_ptr->function) {
    const std::lock_guard<std::mutex> lock{_jiterator_mutex};
    if (!fn_ptr->function) {
      auto code = at::cuda::jit::generate_code(nInputs, nOutputs, f, name,
                    f_inputs_type_str, compute_type_str, result_type_str,
                    contiguous, dynamic_casting,
                    at::cuda::jit::BinaryFuncVariant::NoScalar,
                    extra_args_types, /*vectorized*/false, /*vec_size*/0, return_by_ref);
      *fn_ptr = at::cuda::jit::jit_pwise_function(code, name);
    }
  }

  float scalar_val = 0;

  // pack args for kernel launch
  constexpr int kernel_args = 7;
  auto extra_args_size = extra_args.size();
  auto args = std::make_unique<void*[]>(kernel_args + extra_args_size);
  args[0] = static_cast<void*>(&N);
  args[1] = data_ptr;
  args[2] = ic_ptr;
  args[3] = oc_ptr;
  args[4] = l_ptr;
  args[5] = s_ptr;
  args[6] = static_cast<void*>(&scalar_val);

  for (const auto i : c10::irange(extra_args_size)) {
    // since 7 slots are already filled in `args`
    args[i + 7] = const_cast<void*>(extra_args[i].data_ptr());
  }

  at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args.get(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
}


void jitted_gpu_kernel_dynamic_impl(
    const std::string& kernel_name,
    TensorIteratorBase& iter,
    const std::string& f,
    const bool dynamic_casting,
    const c10::SmallVector<at::Scalar>& extra_args,
    bool return_by_ref) {

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.noutputs() <= 8);
  TORCH_INTERNAL_ASSERT(iter.ninputs() <= 8);

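  // ArrayVariant packs the iterator's output and input data pointers into a single array
  // sized for the runtime tensor count, so it can be passed to the jitted kernel.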
  ArrayVariant data(iter);
  void* data_ptr = data.data_ptr();

  int64_t numel = iter.numel();
  bool contiguous = iter.is_contiguous();

  // Decides which of 4 kernel types to launch
  // Variations are:
  //   - Case 1: no dynamic casting and contiguous
  //   - Case 2: no dynamic casting and noncontiguous
  //   - Case 3: dynamic casting and contiguous
  //   - Case 4: dynamic casting and noncontiguous
  // These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl

  if (!dynamic_casting) {
    if (contiguous) {
      // Case 1: no dynamic casting and contiguous
      launch_jitted_vectorized_kernel_dynamic(kernel_name, iter,
        iter.device().index(), numel, f, data_ptr, extra_args, return_by_ref);
      return;
    }

    // Case 2: no dynamic casting and noncontiguous
    OffsetCalculatorVariant</*is_input=*/true> input_offset_calculator(iter);
    void* ic_ptr = input_offset_calculator.data_ptr();
    OffsetCalculatorVariant</*is_input=*/false> output_offset_calculator(iter);
    void* oc_ptr = output_offset_calculator.data_ptr();

    auto loader = memory::LoadWithoutCast();
    auto storer = memory::StoreWithoutCast();
    void* l_ptr = static_cast<void*>(&loader);
    void* s_ptr = static_cast<void*>(&storer);

    launch_jitted_unrolled_kernel_dynamic(
      kernel_name, iter, iter.device().index(), numel, f, data_ptr,
      ic_ptr, oc_ptr, l_ptr, s_ptr, contiguous, dynamic_casting, extra_args, return_by_ref);

    return;
  }

  // Cases 3 and 4 are handled below
  // Both require construction of one or more storers and loaders

  // Creates casting loaders for the inputs (note the offset indexing: the inputs are
  // tensors noutputs...ntensors-1 in the TensorIterator)
  LoadWithCastVariant loader(iter);
  void* l_ptr = loader.data_ptr();

  // Creates a casting storer for the outputs (tensors 0...noutputs-1 in the TensorIterator)
  StoreWithCastVariant storer(iter);
  void* s_ptr = storer.data_ptr();

  if (contiguous) {
    // Case 3: dynamic casting and contiguous
    TrivialOffsetCalculatorVariant input_offset_calculator(iter.ninputs());
    void* ic_ptr = input_offset_calculator.data_ptr();
    TrivialOffsetCalculatorVariant output_offset_calculator(iter.noutputs());
    void* oc_ptr = output_offset_calculator.data_ptr();

    launch_jitted_unrolled_kernel_dynamic(
      kernel_name, iter, iter.device().index(), numel, f, data_ptr,
      ic_ptr, oc_ptr, l_ptr, s_ptr, contiguous, dynamic_casting, extra_args, return_by_ref);
    return;
  }

  // Case 4: dynamic casting and noncontiguous
  OffsetCalculatorVariant</*is_input=*/true> input_offset_calculator(iter);
  void* ic_ptr = input_offset_calculator.data_ptr();
  OffsetCalculatorVariant</*is_input=*/false> output_offset_calculator(iter);
  void* oc_ptr = output_offset_calculator.data_ptr();

  launch_jitted_unrolled_kernel_dynamic(
    kernel_name, iter, iter.device().index(), numel, f, data_ptr,
    ic_ptr, oc_ptr, l_ptr, s_ptr, contiguous, dynamic_casting, extra_args, return_by_ref);
}

// Entrypoint for the dynamic version of the jitted GPU kernels, which accepts a dynamic number
// of inputs and arbitrary types of inputs and extra args. This dynamic version is needed for the
// jiterator's Python interface, where the kernel definition is not known at compile time.
// Similarly, launch_jitted_vectorized_kernel_dynamic and launch_jitted_unrolled_kernel_dynamic
// handle arbitrary functions defined in Python user code.
// For the templated version, see note [Jiterator] in JitLoops.cuh for more details.
void jitted_gpu_kernel_dynamic(
    const std::string& kernel_name,
    TensorIteratorBase& iter,
    const std::string& f,
    const c10::SmallVector<at::Scalar>& extra_args,
    bool return_by_ref) {

  // TODO: much of preamble is common to both jitted_gpu_kernel and gpu_kernel
  // Maybe it could be refactored?
  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(
      iter.device(arg).is_cuda(),
      "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
  }

  if (iter.numel() == 0) {
    return;
  }

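  // The jitted kernels index with 32-bit offsets; if this iterator cannot use 32-bit
  // indexing, split it into sub-iterators that can and recurse on each of them.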
  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      jitted_gpu_kernel_dynamic(kernel_name, sub_iter, f, extra_args, return_by_ref);
    }
    return;
  }

  // Computes if dynamic casting is needed
  // Dynamic casting is needed if an input's or output's dtype differs from the common dtype
  bool needs_dynamic_casting = false;
  const at::ScalarType common_dtype = iter.common_dtype();
  for (auto i = 0; i < iter.ntensors(); ++i) {
    if (iter.dtype(i) != common_dtype) {
      needs_dynamic_casting = true;
      break;
    }
  }

  jitted_gpu_kernel_dynamic_impl(kernel_name, iter, f, needs_dynamic_casting, extra_args, return_by_ref);
}

} // namespace native

namespace cuda {

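// A minimal usage sketch (hypothetical caller, not part of this file), assuming the code string
// follows the jiterator convention of a templated C++ definition whose name matches kernel_name:
//
//   const std::string code =
//       "template <typename T> T my_add(T a, T b) { return a + b; }";
//   c10::SmallVector<at::Tensor> inputs = {a, b};  // CUDA tensors with a common dtype
//   c10::SmallVector<at::Scalar> extra;            // no extra scalar arguments
//   auto outs = at::cuda::CompileAndLaunchKernel(
//       code, /*kernel_name=*/"my_add", /*num_outputs=*/1, inputs, extra,
//       /*return_by_ref=*/false);
//   // outs[0] holds the elementwise result of my_add over the (broadcast) inputs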
c10::SmallVector<at::Tensor> CompileAndLaunchKernel(
    const std::string& code_string,
    const std::string& kernel_name,
    const int num_outputs,
    const c10::SmallVector<at::Tensor>& tensors,
    const c10::SmallVector<at::Scalar>& extra_args,
    bool return_by_ref) {

  c10::SmallVector<at::Tensor> outs(num_outputs);
  TensorIteratorConfig config;
  config
    .set_check_mem_overlap(true)
    .allow_cpu_scalars(false)
    .promote_inputs_to_common_dtype(true)
    .cast_common_dtype_to_outputs(true)
    .enforce_safe_casting_to_output(true)
    .check_all_same_device(true);
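  // The entries of `outs` are still undefined tensors here; registering them with
  // add_owned_output lets the TensorIterator allocate the outputs with the computed
  // common dtype and broadcast shape when build() runs.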
  for (int i = 0; i < num_outputs; ++i) {
    config.add_owned_output(outs[i]);
  }
  for (const auto& t: tensors) {
    config.add_const_input(t);
  }
  TensorIterator iter = config.build();

  CUDAGuard guard(iter.device());
  at::native::jitted_gpu_kernel_dynamic(kernel_name, iter, code_string, extra_args, return_by_ref);

  c10::SmallVector<at::Tensor> outputs;
  for (int i = 0; i < num_outputs; ++i) {
    outputs.emplace_back(iter.output(i));
  }

  return outputs;
}

}} // namespace at::cuda

#endif // AT_USE_JITERATOR()