#pragma once

#include <ATen/code_template.h>
#include <torch/csrc/Export.h>

namespace torch::jit::fuser::cuda {

/* With type_as not checking the type of its input, a fusion group can have a
non-fp32 tensor as input. Correct code is generated for this case; however,
nvrtc does not know how to handle the int*_t integer types, so these typedefs
help it handle those cases. */

static constexpr auto bfloat16_type_string = "__nv_bfloat16";

#if defined(USE_ROCM)
static auto type_declarations_template = at::jit::CodeTemplate(R"(
${HalfHeader}
${BFloat16Header}
${RandHeader}

#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)

typedef ${IndexType} IndexType;
template<typename T, size_t N>
struct TensorInfo {
  T* data;
  IndexType sizes[N];
  IndexType strides[N];
};
template<typename T>
struct TensorInfo<T, 0> {
  T* data;
};
)");
#else
static auto type_declarations_template = at::jit::CodeTemplate(R"(
typedef unsigned char uint8_t;
typedef signed char int8_t;
typedef short int int16_t;
typedef long long int int64_t;
typedef unsigned long long int uint64_t;
${HalfHeader}
${BFloat16Header}
${RandHeader}

#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)

typedef ${IndexType} IndexType;
template<typename T, size_t N>
struct TensorInfo {
  T* data;
  IndexType sizes[N];
  IndexType strides[N];
};
template<typename T>
struct TensorInfo<T, 0> {
  T* data;
};
)");
#endif
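
// For illustration only: the fuser fills this template through
// at::jit::TemplateEnv, roughly as sketched below. The concrete placeholder
// values shown here are assumptions, not necessarily the exact strings the
// code generator produces.
//
//   at::jit::TemplateEnv env;
//   env.s("IndexType", "unsigned int");
//   env.s("HalfHeader", half_support_literal);
//   env.s("BFloat16Header", bfloat16_support_literal);
//   env.s("RandHeader", "");  // or rand_support_literal when RNG is needed
//   std::string decls = type_declarations_template.format(env);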

// We rewrite the code for philox RNG from curand as nvrtc couldn't resolve the
// curand header correctly.
constexpr auto rand_support_literal = R"(

  class Philox {
  public:
    __device__ inline Philox(unsigned long long seed,
                             unsigned long long subsequence,
                             unsigned long long offset) {
      key.x = (unsigned int)seed;
      key.y = (unsigned int)(seed >> 32);
      counter = make_uint4(0, 0, 0, 0);
      counter.z = (unsigned int)(subsequence);
      counter.w = (unsigned int)(subsequence >> 32);
      STATE = 0;
      incr_n(offset / 4);
    }

    __device__ inline unsigned long operator()() {
      if (STATE == 0) {
        uint4 counter_ = counter;
        uint2 key_ = key;
        for (int i = 0; i < 9; i++) {
          counter_ = single_round(counter_, key_);
          key_.x += (kPhilox10A); key_.y += (kPhilox10B);
        }
        output = single_round(counter_, key_);
        incr();
      }
      unsigned long ret;
      switch (STATE) {
        case 0: ret = output.x; break;
        case 1: ret = output.y; break;
        case 2: ret = output.z; break;
        case 3: ret = output.w; break;
      }
      STATE = (STATE + 1) % 4;
      return ret;
    }

  private:
    uint4 counter;
    uint4 output;
    uint2 key;
    unsigned int STATE;
    __device__ inline void incr_n(unsigned long long n) {
      unsigned int nlo = (unsigned int)(n);
      unsigned int nhi = (unsigned int)(n >> 32);
      counter.x += nlo;
      if (counter.x < nlo)
        nhi++;
      counter.y += nhi;
      if (nhi <= counter.y)
        return;
      if (++counter.z)
        return;
      ++counter.w;
    }
    __device__ inline void incr() {
      if (++counter.x)
        return;
      if (++counter.y)
        return;
      if (++counter.z)
        return;
      ++counter.w;
    }
    __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
                                      unsigned int *result_high) {
      *result_high = __umulhi(a, b);
      return a*b;
    }

    __device__ inline uint4 single_round(uint4 ctr, uint2 key) {
      unsigned int hi0;
      unsigned int hi1;
      unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
      unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);

      uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
      return ret;
    }

    static const unsigned long kPhilox10A = 0x9E3779B9;
    static const unsigned long kPhilox10B = 0xBB67AE85;
    static const unsigned long kPhiloxSA = 0xD2511F53;
    static const unsigned long kPhiloxSB = 0xCD9E8D57;
  };

  // Inverse of 2^32.
  #define M_RAN_INVM32 2.3283064e-10f
  __device__ __inline__ float uniform(unsigned int x) {
    return x * M_RAN_INVM32;
  }
)";

constexpr auto rand_param =
    ",unsigned long long seed, unsigned long long offset";

constexpr auto rand_init = R"(
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  Philox rnd(seed, idx, offset);
)";
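
// For illustration only: a kernel that needs RNG gets the extra seed/offset
// formals from rand_param, runs rand_init once per thread, and can then draw
// uniform floats in [0, 1) roughly like this (the exact expression is emitted
// by the code generator):
//
//   float r = uniform(rnd());  // rnd() yields 32 random bits per call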

static auto cuda_compilation_unit_template = at::jit::CodeTemplate(R"(
${type_declarations}

extern "C" __global__
void ${kernelName}(IndexType totalElements, ${formals} ${RandParam}) {
  ${RandInit}
  // check whether to do vectorized load/store and allocate buffer
  bool flag_vec4 = true;
  ${tensorChecks}
  if (flag_vec4) {
    for (IndexType linearIndex = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
         linearIndex < totalElements;
         linearIndex += 4 * gridDim.x * blockDim.x) {
      // Convert `linearIndex` into an offset of the tensor:
      ${tensorOffsets}
      // load 4 elements at a time
      ${kernelLoad}
      #pragma unroll 4
      for (int i = 0; i < 4; i++) {
        // calculate the results
        ${kernelBody_vec4}
      }
      // store 4 elements at a time
      ${kernelStore}
    }
  } else {
    for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
         linearIndex < totalElements;
         linearIndex += gridDim.x * blockDim.x) {
      // Convert `linearIndex` into an offset of the tensor:
      ${tensorOffsets}
      // calculate the results
      ${kernelBody}
    }
  }
}
)");
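
// For illustration only: for a hypothetical fusion over a single contiguous
// float tensor `t0` with a pointwise body computing `x + 1`, the placeholders
// above would expand to something like the following (names and offset
// arithmetic are illustrative; the real expansion comes from the fuser's code
// generator and depends on rank, strides, and dtype):
//
//   // ${tensorOffsets}
//   IndexType t0_offset = linearIndex;
//   // ${kernelBody}
//   float n0 = t0.data[t0_offset];
//   t0.data[t0_offset] = n0 + 1.f;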

// This snippet enables half support in the jit. Following the pattern for
// reductions, fp16 input data is immediately upconverted to float
// with __half2float(). All mathematical operations are done on float
// values, and if needed the intermediate float representation is
// converted to half with __float2half() when writing to a half tensor.
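// A hypothetical generated statement for a half input `t1` and a half output
// `t0` would therefore follow this pattern (names are illustrative only):
//
//   float n1 = __half2float(t1.data[t1_offset]);  // upconvert on load
//   float r = n1 * 2.f;                           // compute in float
//   t0.data[t0_offset] = __float2half(r);         // downconvert on store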
#if defined(USE_ROCM)
constexpr auto half_support_literal =
    R"(
typedef __half half;
)";
#else
constexpr auto half_support_literal =
    R"(
#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
#if defined(__cplusplus)
  struct __align__(2) __half {
    __host__ __device__ __half() { }

  protected:
    unsigned short __x;
  };

  /* All intrinsic functions are only available to nvcc compilers */
  #if defined(__CUDACC__)
    /* Definitions of intrinsics */
    __device__ __half __float2half(const float f) {
      __half val;
      asm("{  cvt.rn.f16.f32 %0, %1;}\n" : "=h"(__HALF_TO_US(val)) : "f"(f));
      return val;
    }

    __device__ float __half2float(const __half h) {
      float val;
      asm("{  cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__HALF_TO_CUS(h)));
      return val;
    }
)"
    // MSVC's preprocessor (but not the standard compiler) has a bug where it
    // incorrectly tokenizes raw string literals, ending them when it sees a ".
    // This causes the #endif in this string literal to be treated as a
    // preprocessor token, which, in turn, causes sccache on Windows CI to fail.
    // See https://godbolt.org/z/eVTIJq for an example.
    // This workaround uses string pasting to separate the " and the #endif into
    // different strings.
    R"(
  #endif /* defined(__CUDACC__) */
#endif /* defined(__cplusplus) */
#undef __HALF_TO_US
#undef __HALF_TO_CUS

typedef __half half;
)";
#endif

#if defined(USE_ROCM)
constexpr auto bfloat16_support_literal =
    R"(
#ifndef __align__
#define __align__(x) __attribute__((aligned(x)))
#endif

typedef struct __align__(2) {
  unsigned short x;
}
__nv_bfloat16_raw;

#if defined(__cplusplus)
struct __align__(2) __nv_bfloat16 {
  __host__ __device__ __nv_bfloat16() {}

  __host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) {
    __x = hr.x;
    return *this;
  }

  unsigned short __x;
};

__device__ unsigned short __internal_float2bfloat16(
    const float f,
    unsigned int& sign,
    unsigned int& remainder) {
  unsigned int x;

  x = __float_as_uint(f);

  if ((x & 0x7fffffffU) > 0x7f800000U) {
    sign = 0U;
    remainder = 0U;
    return static_cast<unsigned short>(0x7fffU);
  }
  sign = x >> 31;
  remainder = x << 16;
  return static_cast<unsigned short>(x >> 16);
}

/* Definitions of intrinsics */
__device__ __nv_bfloat16 __float2bfloat16(const float a) {
  __nv_bfloat16 val;
  __nv_bfloat16_raw r;
  unsigned int sign;
  unsigned int remainder;
  r.x = __internal_float2bfloat16(a, sign, remainder);
  if ((remainder > 0x80000000U) ||
      ((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) {
    r.x++;
  }
  val = r;
  return val;
}

__device__ float __bfloat162float(const __nv_bfloat16 a) {
  union
  {
      uint32_t int32;
      float    fp32;
  } u = {uint32_t(a.__x) << 16};
  return u.fp32;
}
#endif /* defined(__cplusplus) */
)";
#else
constexpr auto bfloat16_support_literal =
    R"(
#define __BFLOAT16_TO_US(var) *(reinterpret_cast<unsigned short*>(&(var)))
#define __BFLOAT16_TO_CUS(var) \
  *(reinterpret_cast<const unsigned short*>(&(var)))

typedef struct __align__(2) {
  unsigned short x;
}
__nv_bfloat16_raw;

#if defined(__cplusplus)
struct __align__(2) __nv_bfloat16 {
  __host__ __device__ __nv_bfloat16() {}

  __host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) {
    __x = hr.x;
    return *this;
  }

 protected:
  unsigned short __x;
};

#if defined(__CUDACC__)
__device__ unsigned short __internal_float2bfloat16(
    const float f,
    unsigned int& sign,
    unsigned int& remainder) {
  unsigned int x;

  x = __float_as_uint(f);

  if ((x & 0x7fffffffU) > 0x7f800000U) {
    sign = 0U;
    remainder = 0U;
    return static_cast<unsigned short>(0x7fffU);
  }
  sign = x >> 31;
  remainder = x << 16;
  return static_cast<unsigned short>(x >> 16);
}

/* Definitions of intrinsics */
__device__ __nv_bfloat16 __float2bfloat16(const float a) {
  __nv_bfloat16 val;
#if __CUDA_ARCH__ >= 800
  asm("{  cvt.rn.bf16.f32 %0, %1;}\n" : "=h"(__BFLOAT16_TO_US(val)) : "f"(a));
#else
  __nv_bfloat16_raw r;
  unsigned int sign;
  unsigned int remainder;
  r.x = __internal_float2bfloat16(a, sign, remainder);
  if ((remainder > 0x80000000U) ||
      ((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) {
    r.x++;
  }
  val = r;
#endif
  return val;
}

__device__ float __bfloat162float(const __nv_bfloat16 a) {
  float val;
  asm("{ mov.b32 %0, {0,%1};}\n" : "=f"(val) : "h"(__BFLOAT16_TO_CUS(a)));
  return val;
}
#endif /* defined(__CUDACC__) */
#endif /* defined(__cplusplus) */
#undef __BFLOAT16_TO_US
#undef __BFLOAT16_TO_CUS
)";
#endif
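
// Worked example of the round-to-nearest-even logic above (illustrative):
// 1.0f has bit pattern 0x3f800000, so the bf16 candidate is the top 16 bits,
// 0x3f80, and the remainder is 0, so no increment is applied. For a bit
// pattern such as 0x3f8080ff, the remainder 0x80ff0000 exceeds 0x80000000,
// so the candidate 0x3f80 rounds up to 0x3f81.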

} // namespace torch::jit::fuser::cuda