#pragma once

#include <ATen/code_template.h>
#include <torch/csrc/Export.h>

namespace torch::jit::fuser::cuda {

/* Because type_as does not check the type of its input, a fusion group can
   have a non-fp32 tensor as input. Correct code is generated for this case,
   but nvrtc does not know how to handle the int*_t integer types, so the
   typedefs below help it handle those cases. */

static constexpr auto bfloat16_type_string = "__nv_bfloat16";

#if defined(USE_ROCM)
static auto type_declarations_template = at::jit::CodeTemplate(R"(
${HalfHeader}
${BFloat16Header}
${RandHeader}

#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)

typedef ${IndexType} IndexType;
template<typename T, size_t N>
struct TensorInfo {
  T* data;
  IndexType sizes[N];
  IndexType strides[N];
};
template<typename T>
struct TensorInfo<T, 0> {
  T* data;
};
)");
#else
static auto type_declarations_template = at::jit::CodeTemplate(R"(
typedef unsigned char uint8_t;
typedef signed char int8_t;
typedef short int int16_t;
typedef long long int int64_t;
typedef unsigned long long int uint64_t;
${HalfHeader}
${BFloat16Header}
${RandHeader}

#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)

typedef ${IndexType} IndexType;
template<typename T, size_t N>
struct TensorInfo {
  T* data;
  IndexType sizes[N];
  IndexType strides[N];
};
template<typename T>
struct TensorInfo<T, 0> {
  T* data;
};
)");
#endif
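// Illustrative sketch (not part of this header): the template above is meant
// to be instantiated through at::jit::TemplateEnv, roughly as below. The
// concrete substitution values here are assumptions for the example only.
//
//   at::jit::TemplateEnv env;
//   env.s("IndexType", "unsigned int");
//   env.s("HalfHeader", half_support_literal);
//   env.s("BFloat16Header", bfloat16_support_literal);
//   env.s("RandHeader", rand_support_literal);
//   std::string declarations = type_declarations_template.format(env);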
// We rewrite the code for philox RNG from curand as nvrtc couldn't resolve the
// curand header correctly.
constexpr auto rand_support_literal = R"(

  class Philox {
  public:
    __device__ inline Philox(unsigned long long seed,
                             unsigned long long subsequence,
                             unsigned long long offset) {
      key.x = (unsigned int)seed;
      key.y = (unsigned int)(seed >> 32);
      counter = make_uint4(0, 0, 0, 0);
      counter.z = (unsigned int)(subsequence);
      counter.w = (unsigned int)(subsequence >> 32);
      STATE = 0;
      incr_n(offset / 4);
    }

    __device__ inline unsigned long operator()() {
      if(STATE == 0) {
        uint4 counter_ = counter;
        uint2 key_ = key;
        for(int i = 0; i < 9; i++) {
          counter_ = single_round(counter_, key_);
          key_.x += (kPhilox10A); key_.y += (kPhilox10B);
        }
        output = single_round(counter_, key_);
        incr();
      }
      unsigned long ret;
      switch(STATE) {
        case 0: ret = output.x; break;
        case 1: ret = output.y; break;
        case 2: ret = output.z; break;
        case 3: ret = output.w; break;
      }
      STATE = (STATE + 1) % 4;
      return ret;
    }

  private:
    uint4 counter;
    uint4 output;
    uint2 key;
    unsigned int STATE;
    __device__ inline void incr_n(unsigned long long n) {
      unsigned int nlo = (unsigned int)(n);
      unsigned int nhi = (unsigned int)(n >> 32);
      counter.x += nlo;
      if (counter.x < nlo)
        nhi++;
      counter.y += nhi;
      if (nhi <= counter.y)
        return;
      if (++counter.z)
        return;
      ++counter.w;
    }
    __device__ inline void incr() {
      if (++counter.x)
        return;
      if (++counter.y)
        return;
      if (++counter.z)
        return;
      ++counter.w;
    }
    __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
                                      unsigned int *result_high) {
      *result_high = __umulhi(a, b);
      return a*b;
    }

    __device__ inline uint4 single_round(uint4 ctr, uint2 key) {
      unsigned int hi0;
      unsigned int hi1;
      unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
      unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);

      uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
      return ret;
    }

    static const unsigned long kPhilox10A = 0x9E3779B9;
    static const unsigned long kPhilox10B = 0xBB67AE85;
    static const unsigned long kPhiloxSA = 0xD2511F53;
    static const unsigned long kPhiloxSB = 0xCD9E8D57;
  };

  // Inverse of 2^32.
  #define M_RAN_INVM32 2.3283064e-10f
  __device__ __inline__ float uniform(unsigned int x) {
    return x * M_RAN_INVM32;
  }
)";
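// Illustrative sketch (it mirrors rand_init below) of how a generated kernel
// draws random numbers with the class above: each thread constructs its own
// Philox stream keyed by (seed, thread index, offset) and maps the 32-bit
// outputs to floats in [0, 1) with uniform().
//
//   int idx = blockIdx.x * blockDim.x + threadIdx.x;
//   Philox rnd(seed, idx, offset);
//   float r = uniform(rnd()); // uniformly distributed in [0, 1)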
constexpr auto rand_param =
    ",unsigned long long seed, unsigned long long offset";

constexpr auto rand_init = R"(
  int idx = blockIdx.x*blockDim.x + threadIdx.x;
  Philox rnd(seed, idx, offset);
)";

static auto cuda_compilation_unit_template = at::jit::CodeTemplate(R"(
${type_declarations}

extern "C" __global__
void ${kernelName}(IndexType totalElements, ${formals} ${RandParam}) {
  ${RandInit}
  // check whether to do vectorized load/store and allocate buffer
  bool flag_vec4 = true;
  ${tensorChecks}
  if (flag_vec4) {
    for (IndexType linearIndex = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
         linearIndex < totalElements;
         linearIndex += 4 * gridDim.x * blockDim.x) {
      // Convert `linearIndex` into an offset into each tensor:
      ${tensorOffsets}
      // load 4 at a time
      ${kernelLoad}
      #pragma unroll 4
      for (int i = 0; i < 4; i++) {
        // calculate the results
        ${kernelBody_vec4}
      }
      // store 4 at a time
      ${kernelStore}
    }
  } else {
    for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
         linearIndex < totalElements;
         linearIndex += gridDim.x * blockDim.x) {
      // Convert `linearIndex` into an offset into each tensor:
      ${tensorOffsets}
      // calculate the results
      ${kernelBody}
    }
  }
}
)");
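// Illustrative sketch (assumed substitutions, not verbatim fuser output): for
// one float input and one float output, the scalar (non-vectorized) branch of
// the template above expands to roughly the following.
//
//   extern "C" __global__
//   void fused_kernel_0(IndexType totalElements,
//                       TensorInfo<float, 1> t0, TensorInfo<float, 1> t1) {
//     bool flag_vec4 = true;
//     ...
//     for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
//          linearIndex < totalElements;
//          linearIndex += gridDim.x * blockDim.x) {
//       IndexType t0_offset = linearIndex * t0.strides[0];
//       float n0 = t0.data[t0_offset];
//       t1.data[linearIndex * t1.strides[0]] = n0 * 2.f;
//     }
//   }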
// This snippet enables half support in the jit. Following the pattern for
// reductions, fp16 input data is immediately upconverted to float
// with __half2float(). All mathematical operations are done on float
// values, and if needed the intermediate float representation is
// converted to half with __float2half() when writing to a half tensor.
#if defined(USE_ROCM)
constexpr auto half_support_literal =
    R"(
typedef __half half;
)";
#else
constexpr auto half_support_literal =
    R"(
#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
#if defined(__cplusplus)
struct __align__(2) __half {
  __host__ __device__ __half() { }

 protected:
  unsigned short __x;
};

/* All intrinsic functions are only available to nvcc compilers */
#if defined(__CUDACC__)
/* Definitions of intrinsics */
__device__ __half __float2half(const float f) {
  __half val;
  asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(__HALF_TO_US(val)) : "f"(f));
  return val;
}

__device__ float __half2float(const __half h) {
  float val;
  asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__HALF_TO_CUS(h)));
  return val;
}
)"
    // MSVC's preprocessor (but not the standard compiler) has a bug where it
    // incorrectly tokenizes raw string literals, ending them when it sees a ".
    // This causes the #endif in this string literal to be treated as a
    // preprocessor token which, in turn, causes sccache on Windows CI to fail.
    // See https://godbolt.org/z/eVTIJq as an example.
    // This workaround uses string-pasting to separate the " and the #endif
    // into different strings.
    R"(
#endif /* defined(__CUDACC__) */
#endif /* defined(__cplusplus) */
#undef __HALF_TO_US
#undef __HALF_TO_CUS

typedef __half half;
)";
#endif
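// Illustrative sketch of the pattern described above, as it appears inside
// generated kernel bodies (variable names are assumptions for the example):
//
//   half h_in = ...;               // value loaded from a half tensor
//   float x = __half2float(h_in);  // upconvert immediately
//   float y = x * x + 1.f;         // all math is done in float
//   half h_out = __float2half(y);  // downconvert only when storing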
#if defined(USE_ROCM)
constexpr auto bfloat16_support_literal =
    R"(
#ifndef __align__
#define __align__(x) __attribute__((aligned(x)))
#endif

typedef struct __align__(2) {
  unsigned short x;
} __nv_bfloat16_raw;

#if defined(__cplusplus)
struct __align__(2) __nv_bfloat16 {
  __host__ __device__ __nv_bfloat16() {}

  __host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) {
    __x = hr.x;
    return *this;
  }

  unsigned short __x;
};

__device__ unsigned short __internal_float2bfloat16(
    const float f,
    unsigned int& sign,
    unsigned int& remainder) {
  unsigned int x;

  x = __float_as_uint(f);

  if ((x & 0x7fffffffU) > 0x7f800000U) {
    sign = 0U;
    remainder = 0U;
    return static_cast<unsigned short>(0x7fffU);
  }
  sign = x >> 31;
  remainder = x << 16;
  return static_cast<unsigned short>(x >> 16);
}

/* Definitions of intrinsics */
__device__ __nv_bfloat16 __float2bfloat16(const float a) {
  __nv_bfloat16 val;
  __nv_bfloat16_raw r;
  unsigned int sign;
  unsigned int remainder;
  r.x = __internal_float2bfloat16(a, sign, remainder);
  if ((remainder > 0x80000000U) ||
      ((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) {
    r.x++;
  }
  val = r;
  return val;
}

__device__ float __bfloat162float(const __nv_bfloat16 a) {
  union {
    uint32_t int32;
    float fp32;
  } u = {uint32_t(a.__x) << 16};
  return u.fp32;
}
#endif /* defined(__cplusplus) */
)";
#else
constexpr auto bfloat16_support_literal =
    R"(
#define __BFLOAT16_TO_US(var) *(reinterpret_cast<unsigned short*>(&(var)))
#define __BFLOAT16_TO_CUS(var) \
  *(reinterpret_cast<const unsigned short*>(&(var)))

typedef struct __align__(2) {
  unsigned short x;
} __nv_bfloat16_raw;

#if defined(__cplusplus)
struct __align__(2) __nv_bfloat16 {
  __host__ __device__ __nv_bfloat16() {}

  __host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) {
    __x = hr.x;
    return *this;
  }

 protected:
  unsigned short __x;
};

#if defined(__CUDACC__)
__device__ unsigned short __internal_float2bfloat16(
    const float f,
    unsigned int& sign,
    unsigned int& remainder) {
  unsigned int x;

  x = __float_as_uint(f);

  if ((x & 0x7fffffffU) > 0x7f800000U) {
    sign = 0U;
    remainder = 0U;
    return static_cast<unsigned short>(0x7fffU);
  }
  sign = x >> 31;
  remainder = x << 16;
  return static_cast<unsigned short>(x >> 16);
}

/* Definitions of intrinsics */
__device__ __nv_bfloat16 __float2bfloat16(const float a) {
  __nv_bfloat16 val;
#if __CUDA_ARCH__ >= 800
  asm("{ cvt.rn.bf16.f32 %0, %1;}\n" : "=h"(__BFLOAT16_TO_US(val)) : "f"(a));
#else
  __nv_bfloat16_raw r;
  unsigned int sign;
  unsigned int remainder;
  r.x = __internal_float2bfloat16(a, sign, remainder);
  if ((remainder > 0x80000000U) ||
      ((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) {
    r.x++;
  }
  val = r;
#endif
  return val;
}

__device__ float __bfloat162float(const __nv_bfloat16 a) {
  float val;
  asm("{ mov.b32 %0, {0,%1};}\n" : "=f"(val) : "h"(__BFLOAT16_TO_CUS(a)));
  return val;
}
#endif /* defined(__CUDACC__) */
#endif /* defined(__cplusplus) */
#undef __BFLOAT16_TO_US
#undef __BFLOAT16_TO_CUS
)";
#endif

} // namespace torch::jit::fuser::cuda
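// Illustrative note on the bfloat16 conversions above (assumed usage, for
// exposition only): __float2bfloat16 rounds to nearest-even and keeps just
// the top 7 mantissa bits of the float, so the round trip is lossy:
//
//   __nv_bfloat16 b = __float2bfloat16(3.14159f);
//   float back = __bfloat162float(b); // == 3.140625f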