// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/jit_utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <string>
4 #include <sstream>
5 #include <unordered_map>
6 #include <vector>
7 
8 #include <c10/util/irange.h>
9 #include <ATen/jit_macros.h>
10 #include <ATen/cuda/detail/LazyNVRTC.h>
11 
12 namespace at { namespace cuda { namespace jit {
13 
// Where (if anywhere) a binary elementwise functor takes a scalar operand:
// no scalar, scalar on the right-hand side, or scalar on the left-hand side.
enum class BinaryFuncVariant {NoScalar, RhsScalar, LhsScalar};
15 
// Handles for a runtime-compiled kernel: the loaded CUDA driver module and
// the kernel entry point extracted from it (see jit_pwise_function below).
struct NvrtcFunction {
  CUmodule module = CUmodule();
  CUfunction function = nullptr;
};
20 
// Describes a kernel to be JIT-generated: its name, the functor source
// string, the scalar types involved, and the input/output arity.
struct KernelDescriptor {
  std::string name;   // kernel name used by the code generator
  std::string f;      // source text of the functor to compile
  c10::ScalarType f_inputs_type;  // scalar type of the functor's inputs
  c10::ScalarType result_type;    // scalar type of the kernel's result
  // Scalar types of any extra (non-tensor) arguments, in declaration order.
  c10::SmallVector<c10::ScalarType> extra_args_types;
  int nInputs, nOutputs;  // number of tensor inputs / outputs
};
29 
30 // Helper function to return a vector<string>
31 // corresponding to the type of the arguments in parameter pack.
32 template <typename... Args>
get_extra_args_types()33 c10::SmallVector<at::ScalarType> get_extra_args_types() {
34   return {c10::CppTypeToScalarType<Args>::value ...};
35 }
36 
37 template <
38   typename result_type,
39   typename f_inputs_type,
40   typename... ExtraArgs>
make_kernel_descriptor(std::string name,std::string f,int nInputs,int nOutputs)41 KernelDescriptor make_kernel_descriptor(
42     std::string name,
43     std::string f,
44     int nInputs,
45     int nOutputs) {
46   KernelDescriptor ret;
47   ret.name = std::move(name);
48   ret.f = std::move(f);
49   ret.f_inputs_type = c10::CppTypeToScalarType<f_inputs_type>::value;
50   ret.result_type = c10::CppTypeToScalarType<result_type>::value;
51   ret.extra_args_types = get_extra_args_types<ExtraArgs...>();
52   ret.nInputs = nInputs;
53   ret.nOutputs = nOutputs;
54   return ret;
55 }
56 
// Returns the largest vectorization factor (4, 2, or 1) such that `pointer`
// is aligned to `factor * default_alignment` bytes.
inline int can_vectorize_up_to(size_t default_alignment, void *pointer) {
  const auto address = reinterpret_cast<uintptr_t>(pointer);
  // Try the widest factor first; fall back to scalar access.
  for (int factor : {4, 2}) {
    if (address % (factor * default_alignment) == 0) {
      return factor;
    }
  }
  return 1;
}
67 
can_vectorize_up_to(const KernelDescriptor & desc,c10::ArrayRef<char * > pointers)68 inline int can_vectorize_up_to(const KernelDescriptor &desc, c10::ArrayRef<char*> pointers) {
69   TORCH_INTERNAL_ASSERT(desc.nOutputs == 1);
70   TORCH_INTERNAL_ASSERT(static_cast<int64_t>(pointers.size()) == 1 + desc.nInputs);
71 
72   // Deals with output
73   auto result_size = c10::scalarTypeToTypeMeta(desc.result_type).itemsize();
74   int result = can_vectorize_up_to(result_size, pointers[0]);
75 
76   // Incorporates input(s)
77   auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize();
78   for (auto i : c10::irange(1, pointers.size())) {
79     result = std::min(result, can_vectorize_up_to(input_size, pointers[i]));
80   }
81 
82   return result;
83 }
84 
// Generates the CUDA source string for an elementwise kernel that applies
// functor `func` (named `name`). The *_type strings name the C++ types used
// in the generated source for the functor inputs, intermediate computation,
// and result. `contiguous`/`dynamic_casting` select the load/store scheme,
// `scalar_pos` handles binary ops with one scalar operand, and
// `vectorized`/`vec_size` request a vectorized variant. `return_by_ref`
// presumably makes the functor write results through reference parameters
// (multi-output case) — confirm against jit_utils.cpp.
std::string generate_code(
    int nInputs,
    int nOutputs,
    const std::string& func,
    const std::string& name,
    const std::string& f_input_type,
    const std::string& compute_type,
    const std::string& result_type,
    bool contiguous,
    bool dynamic_casting,
    BinaryFuncVariant scalar_pos,
    c10::SmallVector<std::string>& extra_args_typenames,
    bool vectorized=false,
    int vec_size=0,
    bool return_by_ref=false);

// Convenience overload: takes the arity, functor source, name, and type
// information from a KernelDescriptor instead of individual arguments.
std::string generate_code(
    const KernelDescriptor &desc,
    bool contiguous,
    bool dynamic_casting,
    BinaryFuncVariant scalar_pos,
    bool vectorized=false,
    int vec_size=0,
    bool return_by_ref=false);
109 
// Generates the CUDA source string for a reduction kernel applying `func`.
// The type strings name the C++ types of the inputs, the accumulator, and
// the result as they appear in the generated source. `vt0` is a tiling
// parameter (looks like values-per-thread — confirm against the reduction
// codegen in jit_utils.cpp), and `max_threads_codegen` bounds the thread
// count baked into the generated code.
std::string generate_reduction_code(
    int nOutputs,
    const std::string& func,
    const std::string& name,
    const int vt0,
    const std::string& f_inputs_type,
    const std::string& reduction_accum_type,
    const std::string& result_type,
    bool contiguous,
    bool vectorized,
    int vec_size,
    int max_threads_codegen);

// Convenience overload: takes functor/type information from a
// KernelDescriptor instead of individual arguments.
std::string generate_reduction_code(
    const KernelDescriptor &desc,
    const int vt0,
    bool contiguous,
    bool vectorized,
    int vec_size,
    int max_threads_codegen);
130 
// Compiles `code` at runtime (via NVRTC) and returns handles to the loaded
// module and the kernel entry point named `kernel_name`.
NvrtcFunction jit_pwise_function(
    const std::string& code,
    const std::string& kernel_name);

// Launches a previously compiled kernel with the given raw kernel-argument
// pointer array, grid/block dimensions, and dynamic shared-memory size in
// bytes (default: none).
void launch_jitted_pwise_function(
    NvrtcFunction function,
    void* args[],
    const dim3 nBlocks,
    const dim3 kBlockSize,
    const int smem=0);
141 
// Always-false trait whose value is only evaluated at template
// instantiation time; used below to make static_assert fire only when the
// generic typeName<T>() is actually instantiated.
template <typename T>
struct delayed_false : std::false_type {
};
145 
// Defines type names
// NOTE: General case is instantiated only for invalid types.
// All the valid types have specialization using the TYPE_NAME_FN
// macro below.
template <typename T>
inline std::string typeName() {
  // We can't use static_assert(false) directly: the program would fail to
  // compile even if this template were never instantiated. `delayed_false`
  // defers evaluation to instantiation time, so the assertion only fires
  // for types that lack a specialization.
  static_assert(delayed_false<T>::value, "invalid type for jiterator");
  return "void";
}
160 
// Generates a typeName<T>() specialization that returns the stringified
// C++ type name for each standard scalar type.
#define TYPE_NAME_FN(ctype, name) \
template <> inline std::string typeName<ctype>(){ \
    return std::string(#ctype);    \
}

AT_FORALL_SCALAR_TYPES(TYPE_NAME_FN)
#undef TYPE_NAME_FN
// Manual specializations for types not covered by AT_FORALL_SCALAR_TYPES.
// JIT uses std::complex directly, because nvRTC compile programs
// with -default-device, so there is no such issue like:
//   "std::sin(complex) is __host__ only"
template <> inline std::string typeName<bool>(){
    return "bool";
}
template <> inline std::string typeName<c10::complex<at::Half>>(){
    return "std::complex<at::Half>";
}
template <> inline std::string typeName<c10::complex<float>>(){
    return "std::complex<float>";
}
template <> inline std::string typeName<c10::complex<double>>(){
    return "std::complex<double>";
}
template <> inline std::string typeName<at::Half>(){
    return "at::Half";
}
template <> inline std::string typeName<at::BFloat16>(){
    return "at::BFloat16";
}
template <> inline std::string typeName<at::Float8_e5m2>(){
    return "at::Float8_e5m2";
}
template <> inline std::string typeName<at::Float8_e4m3fn>(){
    return "at::Float8_e4m3fn";
}
template <> inline std::string typeName<at::Float8_e5m2fnuz>() {
    return "at::Float8_e5m2fnuz";
}
template <> inline std::string typeName<at::Float8_e4m3fnuz>() {
    return "at::Float8_e4m3fnuz";
}
201 
// Runtime overload: maps a ScalarType tag to its C++ type-name string by
// dispatching to the typeName<T>() specializations above. Throws (via
// TORCH_CHECK) for types the jiterator does not support.
#define TYPE_NAME_CASE(ctype, scalartype)                    \
  case ScalarType::scalartype:  return typeName<ctype>();
inline std::string typeName(ScalarType t) {
    switch (t) {
      AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(TYPE_NAME_CASE)
      default:
          TORCH_CHECK(false, "invalid type for jiterator");
    }
}
#undef TYPE_NAME_CASE
212 
// Ensures a CUDA context is set up before driver-API calls — presumably
// lazily creating one if none is current; see jit_utils.cpp for details.
TORCH_CUDA_CPP_API void initializeCudaContext();
214 
215 }}}  // namespace at::cuda::jit
216