1 #pragma once
2
3 #include <string>
4 #include <sstream>
5 #include <unordered_map>
6 #include <vector>
7
8 #include <c10/util/irange.h>
9 #include <ATen/jit_macros.h>
10 #include <ATen/cuda/detail/LazyNVRTC.h>
11
12 namespace at { namespace cuda { namespace jit {
13
// Variant selector passed to generate_code() as `scalar_pos`.
// NOTE(review): presumably names which operand of a binary functor (none,
// right-hand side, or left-hand side) is supplied as a scalar -- confirm
// against the generate_code() definition.
enum class BinaryFuncVariant {
  NoScalar,
  RhsScalar,
  LhsScalar,
};
15
16 struct NvrtcFunction {
17 CUmodule module = CUmodule();
18 CUfunction function = nullptr;
19 };
20
21 struct KernelDescriptor {
22 std::string name;
23 std::string f;
24 c10::ScalarType f_inputs_type;
25 c10::ScalarType result_type;
26 c10::SmallVector<c10::ScalarType> extra_args_types;
27 int nInputs, nOutputs;
28 };
29
30 // Helper function to return a vector<string>
31 // corresponding to the type of the arguments in parameter pack.
32 template <typename... Args>
get_extra_args_types()33 c10::SmallVector<at::ScalarType> get_extra_args_types() {
34 return {c10::CppTypeToScalarType<Args>::value ...};
35 }
36
37 template <
38 typename result_type,
39 typename f_inputs_type,
40 typename... ExtraArgs>
make_kernel_descriptor(std::string name,std::string f,int nInputs,int nOutputs)41 KernelDescriptor make_kernel_descriptor(
42 std::string name,
43 std::string f,
44 int nInputs,
45 int nOutputs) {
46 KernelDescriptor ret;
47 ret.name = std::move(name);
48 ret.f = std::move(f);
49 ret.f_inputs_type = c10::CppTypeToScalarType<f_inputs_type>::value;
50 ret.result_type = c10::CppTypeToScalarType<result_type>::value;
51 ret.extra_args_types = get_extra_args_types<ExtraArgs...>();
52 ret.nInputs = nInputs;
53 ret.nOutputs = nOutputs;
54 return ret;
55 }
56
// Returns 4, 2, or 1: the largest width w in {4, 2} for which the address held
// in `pointer` is a multiple of w * default_alignment, or 1 if neither fits.
// `default_alignment` is used as a modulus and therefore must be non-zero.
inline int can_vectorize_up_to(size_t default_alignment, void *pointer) {
  const auto address = reinterpret_cast<uintptr_t>(pointer);
  // Probe the wider width first: 4, then 2.
  for (int width = 4; width > 1; width /= 2) {
    if (address % (width * default_alignment) == 0) {
      return width;
    }
  }
  return 1;
}
67
can_vectorize_up_to(const KernelDescriptor & desc,c10::ArrayRef<char * > pointers)68 inline int can_vectorize_up_to(const KernelDescriptor &desc, c10::ArrayRef<char*> pointers) {
69 TORCH_INTERNAL_ASSERT(desc.nOutputs == 1);
70 TORCH_INTERNAL_ASSERT(static_cast<int64_t>(pointers.size()) == 1 + desc.nInputs);
71
72 // Deals with output
73 auto result_size = c10::scalarTypeToTypeMeta(desc.result_type).itemsize();
74 int result = can_vectorize_up_to(result_size, pointers[0]);
75
76 // Incorporates input(s)
77 auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize();
78 for (auto i : c10::irange(1, pointers.size())) {
79 result = std::min(result, can_vectorize_up_to(input_size, pointers[i]));
80 }
81
82 return result;
83 }
84
// Generates code and returns it as a string, with every ingredient spelled
// out explicitly (declaration only; the definition is not in this header).
// `vectorized`, `vec_size`, and `return_by_ref` default to false, 0, and false.
// NOTE(review): parameter semantics are not visible from this header.
// Presumably `func` is the functor source, `name` the kernel name, and the
// `*_type` strings are type names as produced by typeName() -- confirm against
// the definition. `extra_args_typenames` is taken by non-const reference;
// verify whether the definition mutates it.
std::string generate_code(
    int nInputs,
    int nOutputs,
    const std::string& func,
    const std::string& name,
    const std::string& f_input_type,
    const std::string& compute_type,
    const std::string& result_type,
    bool contiguous,
    bool dynamic_casting,
    BinaryFuncVariant scalar_pos,
    c10::SmallVector<std::string>& extra_args_typenames,
    bool vectorized=false,
    int vec_size=0,
    bool return_by_ref=false);
100
// Overload of generate_code() that takes a KernelDescriptor instead of
// separate name/functor/type/count arguments (declaration only; the definition
// is not in this header).
// `vectorized`, `vec_size`, and `return_by_ref` default to false, 0, and false.
// NOTE(review): presumably forwards the descriptor's fields to the explicit
// overload -- confirm against the definition.
std::string generate_code(
    const KernelDescriptor &desc,
    bool contiguous,
    bool dynamic_casting,
    BinaryFuncVariant scalar_pos,
    bool vectorized=false,
    int vec_size=0,
    bool return_by_ref=false);
109
// Generates reduction code and returns it as a string, with every ingredient
// spelled out explicitly (declaration only; the definition is not in this
// header). Unlike generate_code(), none of the parameters has a default.
// NOTE(review): parameter semantics are not visible from this header.
// Presumably `func` is the functor source, `name` the kernel name, and the
// `*_type` strings are type names as produced by typeName(); the meaning of
// `vt0` and `max_threads_codegen` should be confirmed against the definition.
std::string generate_reduction_code(
    int nOutputs,
    const std::string& func,
    const std::string& name,
    const int vt0,
    const std::string& f_inputs_type,
    const std::string& reduction_accum_type,
    const std::string& result_type,
    bool contiguous,
    bool vectorized,
    int vec_size,
    int max_threads_codegen);
122
// Overload of generate_reduction_code() that takes a KernelDescriptor instead
// of separate name/functor/type/count arguments (declaration only; the
// definition is not in this header).
// NOTE(review): presumably forwards the descriptor's fields to the explicit
// overload -- confirm against the definition.
std::string generate_reduction_code(
    const KernelDescriptor &desc,
    const int vt0,
    bool contiguous,
    bool vectorized,
    int vec_size,
    int max_threads_codegen);
130
// Produces an NvrtcFunction from a code string and a kernel name (declaration
// only; the definition is not in this header).
// NOTE(review): presumably compiles `code` at runtime and looks up
// `kernel_name` in the resulting module -- confirm against the definition.
NvrtcFunction jit_pwise_function(
    const std::string& code,
    const std::string& kernel_name);
134
// Launches a function previously obtained as an NvrtcFunction (declaration
// only; the definition is not in this header).
// `smem` defaults to 0.
// NOTE(review): presumably `args` is the array of kernel argument pointers,
// `nBlocks`/`kBlockSize` are the grid and block dimensions, and `smem` is the
// shared-memory size in bytes -- confirm against the definition.
void launch_jitted_pwise_function(
    NvrtcFunction function,
    void* args[],
    const dim3 nBlocks,
    const dim3 kBlockSize,
    const int smem=0);
141
// A trait that is always false but still depends on its template argument.
// Using `delayed_false<T>::value` as a static_assert condition postpones the
// failure until the enclosing template is instantiated with a concrete T.
template <typename>
struct delayed_false : std::false_type {};
145
146 // Defines type names
147 // NOTE: General case is instantiated only for invalid types.
148 // All the valid types have specialization using the TYPE_NAME_FN
149 // macro below.
150 template <typename T>
typeName()151 inline std::string typeName() {
152 // we can't use static_assert(false) directly as the
153 // program will be not compiled even if the template is not
154 // instantiated, so we use `delayed_false`
155 // to make sure compiler doesn't eagerly raise
156 // fail this assertion.
157 static_assert(delayed_false<T>::value, "invalid type for jiterator");
158 return "void";
159 }
160
161 #define TYPE_NAME_FN(ctype, name) \
162 template <> inline std::string typeName<ctype>(){ \
163 return std::string(#ctype); \
164 }
165
AT_FORALL_SCALAR_TYPES(TYPE_NAME_FN)166 AT_FORALL_SCALAR_TYPES(TYPE_NAME_FN)
167 #undef TYPE_NAME_FN
// The JIT uses std::complex directly: NVRTC compiles programs with
// -default-device, so issues such as
// "std::sin(complex) is __host__ only" do not arise.
171 template <> inline std::string typeName<bool>(){
172 return "bool";
173 }
174 template <> inline std::string typeName<c10::complex<at::Half>>(){
175 return "std::complex<at::Half>";
176 }
177 template <> inline std::string typeName<c10::complex<float>>(){
178 return "std::complex<float>";
179 }
180 template <> inline std::string typeName<c10::complex<double>>(){
181 return "std::complex<double>";
182 }
183 template <> inline std::string typeName<at::Half>(){
184 return "at::Half";
185 }
186 template <> inline std::string typeName<at::BFloat16>(){
187 return "at::BFloat16";
188 }
189 template <> inline std::string typeName<at::Float8_e5m2>(){
190 return "at::Float8_e5m2";
191 }
192 template <> inline std::string typeName<at::Float8_e4m3fn>(){
193 return "at::Float8_e4m3fn";
194 }
195 template <> inline std::string typeName<at::Float8_e5m2fnuz>() {
196 return "at::Float8_e5m2fnuz";
197 }
198 template <> inline std::string typeName<at::Float8_e4m3fnuz>() {
199 return "at::Float8_e4m3fnuz";
200 }
201
202 #define TYPE_NAME_CASE(ctype, scalartype) \
203 case ScalarType::scalartype: return typeName<ctype>();
typeName(ScalarType t)204 inline std::string typeName(ScalarType t) {
205 switch (t) {
206 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(TYPE_NAME_CASE)
207 default:
208 TORCH_CHECK(false, "invalid type for jiterator");
209 }
210 }
211 #undef TYPE_NAME_CASE
212
// Exported entry point (declaration only; the definition is not in this
// header).
// NOTE(review): presumably ensures a CUDA context exists before the driver
// API is used by the JIT -- confirm against the definition.
TORCH_CUDA_CPP_API void initializeCudaContext();
214
215 }}} // namespace at::cuda::jit
216