#pragma once

#include <ATen/jit_macros.h>

#if AT_USE_JITERATOR()

#include <ATen/cuda/CUDAConfig.h>

#include <ATen/OpMathType.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>

#include <ATen/native/cuda/MemoryAccess.cuh>

#include <ATen/native/cuda/CUDAJitLoops.cuh>

namespace at {
namespace native {

/* Note [Jiterator]
The "jiterator" just-in-time compiles the same kernels that
Loops.cuh (and CUDALoops.cuh) usually build. This reduces build time,
build size, and initial CUDA context size.

By default on non-Windows systems, it also caches compiled kernels in ~/.cache/torch/kernels.
This behavior is controlled with two environment variables:
  - USE_PYTORCH_KERNEL_CACHE, if set to zero then this will disable all cache use
  - PYTORCH_KERNEL_CACHE_PATH, if set specifies the folder to use for cached kernels

The jiterator currently has some limitations, however. It cannot:
  - handle math on complex datatypes
  - handle kernels with scalar parameters

These improvements will likely come soon.

For examples of how to use the jiterator see the i1 and gcd kernel
implementations, which pass jittable strings implementing their
operations instead of the typical CUDA functors.

To pass a runtime argument (similar to lambda captures in non-JIT kernels),
we need to pass additional arguments to `jitted_gpu_kernel` by value.
Currently only primitive C++ types used for computation are valid.
The order of these extra arguments should be the same as the order in which
they appear in the kernel's function signature (see polygamma for an example).

NOTE: One big restriction is that these arguments must come after the
arguments provided by TensorIterator. E.g., while capturing `n`, where
`scalar_t x` and `scalar_t y` are provided by TensorIterator,
* foo(scalar_t x, scalar_t y, int n) works!
* foo(int n, scalar_t x, scalar_t y) doesn't work
* foo(scalar_t x, int n, scalar_t y) doesn't work

An illustrative sketch of this pattern follows this note.
*/
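
// Illustrative sketch of the pattern described above (hypothetical names:
// `my_op_string`, `my_op_name`, and `my_op_kernel_cuda` are placeholders, not
// part of this header). It mirrors how kernels such as polygamma route a
// runtime argument through `extra_args`: the extra argument `n` appears after
// the TensorIterator-provided input in the jittable function's signature.
//
//   // Jittable string: plain CUDA C++ source, compiled at runtime.
//   static const auto my_op_string = jiterator_stringify(
//     template <typename T>
//     T my_op(T x, int n) {  // `x` from the TensorIterator, `n` is the extra arg
//       return x * static_cast<T>(n);
//     }
//   );
//   constexpr char my_op_name[] = "my_op";
//
//   void my_op_kernel_cuda(TensorIteratorBase& iter, int n) {
//     AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "my_op_cuda", [&]() {
//       jitted_gpu_kernel<my_op_name, /*return_type=*/scalar_t,
//                         /*f_inputs_type=*/scalar_t, /*arity=*/1>(
//           iter,
//           my_op_string,
//           at::cuda::jit::BinaryFuncVariant::NoScalar,
//           /*scalar_val=*/0,
//           /*extra_args=*/std::make_tuple(n));
//     });
//   }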

// Entrypoint for jitted GPU kernels.
// Only handles elementwise unary and binary kernels with a
//   common dtype and a single output.
// NOTE: this assumes the op's iterator has a common_dtype.
// NOTE: We use std::tuple instead of a parameter pack for `extra_args`
//   due to the following bug on older versions of clang:
//   https://bugs.llvm.org/show_bug.cgi?id=23029
template <
    char const* name,
    typename return_type,
    typename f_inputs_type,
    int arity,
    typename... Args>
void jitted_gpu_kernel(
    TensorIteratorBase& iter,
    const std::string& f,
    at::cuda::jit::BinaryFuncVariant scalar_pos =
        at::cuda::jit::BinaryFuncVariant::NoScalar,
    at::opmath_type<f_inputs_type> scalar_val = 0,
    std::tuple<Args...> extra_args = std::make_tuple()) {
  // TODO: much of preamble is common to both jitted_gpu_kernel and gpu_kernel
  //   Maybe it could be refactored?
  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(
      iter.device(arg).is_cuda(),
      "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
  }

  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      jitted_gpu_kernel<name, return_type, f_inputs_type, arity>(
          sub_iter, f, scalar_pos, scalar_val, extra_args);
    }

    return;
  }

  // Computes if dynamic casting is needed
  // Dynamic casting is needed if an input's dtype differs from the common dtype
  //   or if the result dtype differs from the output's dtype
  // Note: this is intentionally divergent from calling needs_dynamic_casting,
  //   which is more general and inspects a lambda to determine if dynamic
  //   casting is needed.
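  // Illustrative example: instantiating with f_inputs_type=float over an
  //   iterator whose input tensors are Half would set needs_dynamic_casting
  //   to true below, since the loaded Half values must be cast to float
  //   before the jitted op runs (and the float result cast back on store).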
  bool needs_dynamic_casting = false;

  // Checks output
  const ScalarType return_scalar_type = c10::CppTypeToScalarType<return_type>::value;
  const auto dtype0 = iter.dtype(0);
  if (dtype0 != return_scalar_type) {
    needs_dynamic_casting = true;
  }

  // Checks input(s)
  const ScalarType inputs_scalar_type = c10::CppTypeToScalarType<f_inputs_type>::value;
  for (auto i = decltype(arity){1}; i < (arity + 1); ++i) {
    const auto dtypei = iter.dtype(i);
    if (dtypei != inputs_scalar_type) {
      needs_dynamic_casting = true;
      break;
    }
  }
  if (scalar_pos == at::cuda::jit::BinaryFuncVariant::NoScalar) {
    // NOTE: With `scalar_pos=NoScalar`, `scalar_val` is not used
    // for computation in the generated code and hence we pass a dummy
    // value of `0`.
    jitted_gpu_kernel_impl<
        /*name*/ name,
        /*return_type=*/return_type,
        /*f_inputs_type=*/f_inputs_type,
        arity,
        at::cuda::jit::BinaryFuncVariant::NoScalar>(
        iter, f, needs_dynamic_casting, /*scalar_val=*/scalar_val, extra_args);
  } else if (scalar_pos == at::cuda::jit::BinaryFuncVariant::RhsScalar) {
    jitted_gpu_kernel_impl<
        /*name*/ name,
        /*return_type=*/return_type,
        /*f_inputs_type=*/f_inputs_type,
        arity,
        at::cuda::jit::BinaryFuncVariant::RhsScalar>(
        iter,
        f,
        needs_dynamic_casting,
        scalar_val,
        extra_args);

  } else {
    jitted_gpu_kernel_impl<
        /*name*/ name,
        /*return_type=*/return_type,
        /*f_inputs_type=*/f_inputs_type,
        arity,
        at::cuda::jit::BinaryFuncVariant::LhsScalar>(
        iter,
        f,
        needs_dynamic_casting,
        scalar_val,
        extra_args);
  }
}

// TODO: support runtime state capture similar to `jitted_gpu_kernel`.
template <char const *name, typename return_type, typename f_inputs_type>
void opmath_jitted_gpu_kernel_with_scalars(TensorIteratorBase& iter, const std::string& f) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
  // Currently the jiterator only handles binary functions where both inputs
  //   are of the same type (f_inputs_type).
  using opmath_t = at::opmath_type<f_inputs_type>;
  if (iter.is_cpu_scalar(1)) {
    auto scalar_val = iter.scalar_value<opmath_t>(1);
    iter.remove_operand(1);
    // TODO: When all kernels that use gpu_kernel_with_scalars are
    // ported to structured, this device guard can be deleted.  This
    // works around incorrect device guard generation for pre-structured
    // kernels; structured kernels do it right and we can assume the
    // device is already set correctly.
    const OptionalDeviceGuard device_guard(iter.device(1));
    jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(iter, f, at::cuda::jit::BinaryFuncVariant::LhsScalar, scalar_val);
  } else if (iter.is_cpu_scalar(2)) {
    auto scalar_val = iter.scalar_value<opmath_t>(2);
    iter.remove_operand(2);
    jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(iter, f, at::cuda::jit::BinaryFuncVariant::RhsScalar, scalar_val);
  } else {
    jitted_gpu_kernel<name, return_type, f_inputs_type, 2>(iter, f);
  }
}
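
// Illustrative sketch (hypothetical names) of a caller of
// opmath_jitted_gpu_kernel_with_scalars. The helper peels off a CPU scalar
// operand, if present, and forwards it as `scalar_val`, so the same
// two-argument jittable string serves both the tensor-tensor and
// tensor-scalar cases.
//
//   static const auto my_mul_string = jiterator_stringify(
//     template <typename T>
//     T my_mul(T a, T b) {
//       return a * b;
//     }
//   );
//   constexpr char my_mul_name[] = "my_mul";
//
//   void my_mul_kernel_cuda(TensorIteratorBase& iter) {
//     AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "my_mul_cuda", [&]() {
//       opmath_jitted_gpu_kernel_with_scalars<my_mul_name, scalar_t, scalar_t>(
//           iter, my_mul_string);
//     });
//   }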

}}  // at::native

#endif // AT_USE_JITERATOR()