#define TORCH_ASSERT_NO_OPERATORS
#include <c10/core/ScalarType.h>
#include <c10/util/irange.h>
#include <c10/util/hash.h>
#include <optional>
#include <ATen/jit_macros.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <ATen/code_template.h>
#include <ATen/OpMathType.h>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/cuda/llvm_jit_strings.h>
#include <ATen/native/cuda/reduction_template.cuh>

#include <sstream>
#include <fstream>
#include <cstdio>
#include <iterator> // istreambuf_iterator
#include <cstdlib>
#include <string>

// TODO: C++17 has the filesystem header, which may replace these
#ifdef _WIN32
// On Windows, the POSIX implementations are considered deprecated. We simply map to the newer variant.
#include <process.h>
#include <direct.h>
#include <io.h>
#define access _access
#define getpid _getpid
#define R_OK 4
#define W_OK 2
#define F_OK 0
#else
#include <sys/types.h>
#include <sys/stat.h> // mkdir
#include <unistd.h>
#endif


namespace at::cuda::jit {

// hiprtc already includes some traits, so this removes duplicate definitions of
// integral_constant, is_same, is_integral, enable_if, is_floating_point, is_arithmetic.
// Copied from aten/src/ATen/cuda/llvm_basic.cpp, then modified as above.
// If not compiling for ROCm, return the original get_traits_string().
std::string get_traits_string_but_hiprtc_safe() {
#ifdef USE_ROCM
  return R"ESCAPE(
namespace std {

template <class _Tp>
_Tp&& __declval(int);
template <class _Tp>
_Tp __declval(long);
template <class _Tp>
decltype(__declval<_Tp>(0)) declval() noexcept;

template <class _Tp> struct remove_const {typedef _Tp type;};
template <class _Tp> struct remove_const<const _Tp> {typedef _Tp type;};
template <class _Tp> using remove_const_t = typename remove_const<_Tp>::type;

template <class _Tp> struct remove_volatile {typedef _Tp type;};
template <class _Tp> struct remove_volatile<volatile _Tp> {typedef _Tp type;};
template <class _Tp> using remove_volatile_t = typename remove_volatile<_Tp>::type;

template <class _Tp> struct remove_cv
{typedef typename remove_volatile<typename remove_const<_Tp>::type>::type type;};
template <class _Tp> using remove_cv_t = typename remove_cv<_Tp>::type;

template <class _Tp> struct __libcpp_is_floating_point : public false_type {};
template <> struct __libcpp_is_floating_point<float> : public true_type {};
template <> struct __libcpp_is_floating_point<double> : public true_type {};
template <> struct __libcpp_is_floating_point<long double> : public true_type {};

template <class _Tp>
inline constexpr bool is_arithmetic_v = is_arithmetic<_Tp>::value;

template <class _Tp>
struct __numeric_type
{
    static void __test(...);
    static float __test(float);
    static double __test(char);
    static double __test(int);
    static double __test(unsigned);
    static double __test(long);
    static double __test(unsigned long);
    static double __test(long long);
    static double __test(unsigned long long);
    static double __test(double);
    static long double __test(long double);

    typedef decltype(__test(declval<_Tp>())) type;
    static const bool value = !is_same<type, void>::value;
};

template <>
struct __numeric_type<void>
{
    static const bool value = true;
};

// __promote

template <class _A1, class _A2 = void, class _A3 = void,
          bool = __numeric_type<_A1>::value &&
                 __numeric_type<_A2>::value &&
                 __numeric_type<_A3>::value>
class __promote_imp
{
public:
    static const bool value = false;
};

template <class _A1, class _A2, class _A3>
class __promote_imp<_A1, _A2, _A3, true>
{
private:
    typedef typename __promote_imp<_A1>::type __type1;
    typedef typename __promote_imp<_A2>::type __type2;
    typedef typename __promote_imp<_A3>::type __type3;
public:
    typedef decltype(__type1() + __type2() + __type3()) type;
    static const bool value = true;
};

template <class _A1, class _A2>
class __promote_imp<_A1, _A2, void, true>
{
private:
    typedef typename __promote_imp<_A1>::type __type1;
    typedef typename __promote_imp<_A2>::type __type2;
public:
    typedef decltype(__type1() + __type2()) type;
    static const bool value = true;
};

template <class _A1>
class __promote_imp<_A1, void, void, true>
{
public:
    typedef typename __numeric_type<_A1>::type type;
    static const bool value = true;
};

template <class _A1, class _A2 = void, class _A3 = void>
class __promote : public __promote_imp<_A1, _A2, _A3> {};

} // namespace std
)ESCAPE";
#else
  return get_traits_string();
#endif
}

#ifdef USE_ROCM
const std::string jit_preamble = R"ESCAPE(
#pragma clang force_cuda_host_device begin
)ESCAPE";
const std::string jit_epilogue = R"ESCAPE(
#pragma clang force_cuda_host_device end
)ESCAPE";
#else
const std::string jit_preamble;
const std::string jit_epilogue;
#endif

const std::string jit_common_types = R"ESCAPE(
#ifdef __HIPCC__
#define ERROR_UNSUPPORTED_CAST ;
// corresponds to aten/src/ATen/native/cuda/thread_constants.h
#define CUDA_OR_ROCM_NUM_THREADS 256
// corresponds to aten/src/ATen/cuda/detail/OffsetCalculator.cuh
#define MAX_DIMS 16
#ifndef __forceinline__
#define __forceinline__ inline __attribute__((always_inline))
#endif
#else
//TODO use _assert_fail, because assert is disabled in non-debug builds
#define ERROR_UNSUPPORTED_CAST assert(false);
#define CUDA_OR_ROCM_NUM_THREADS 128
#define MAX_DIMS 25
#endif
#define POS_INFINITY __int_as_float(0x7f800000)
#define INFINITY POS_INFINITY
#define NEG_INFINITY __int_as_float(0xff800000)
#define NAN __int_as_float(0x7fffffff)

typedef long long int int64_t;
typedef unsigned int uint32_t;
typedef signed char int8_t;
typedef unsigned char uint8_t; // NOTE: this MUST be "unsigned char"! "char" is equivalent to "signed char"
typedef short int16_t;
static_assert(sizeof(int64_t) == 8, "expected size does not match");
static_assert(sizeof(uint32_t) == 4, "expected size does not match");
static_assert(sizeof(int8_t) == 1, "expected size does not match");
constexpr int num_threads = CUDA_OR_ROCM_NUM_THREADS;
constexpr int thread_work_size = 4; // TODO: make template substitution once we decide where those vars live
constexpr int block_work_size = thread_work_size * num_threads;
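// Each block therefore covers block_work_size consecutive elements:
// thread_work_size elements per thread across num_threads threads
// (this matches the loops in the generated kernels below).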

${traits_string}
${cmath_string}

// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
// Note, some types have ctype as void because we don't support them in codegen
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
  _(uint8_t, Byte) /* 0 */ \
  _(int8_t, Char) /* 1 */ \
  _(int16_t, Short) /* 2 */ \
  _(int, Int) /* 3 */ \
  _(int64_t, Long) /* 4 */ \
  _(at::Half, Half) /* 5 */ \
  _(float, Float) /* 6 */ \
  _(double, Double) /* 7 */ \
  _(std::complex<at::Half>, ComplexHalf) /* 8 */ \
  _(std::complex<float>, ComplexFloat) /* 9 */ \
  _(std::complex<double>, ComplexDouble) /* 10 */ \
  _(bool, Bool) /* 11 */ \
  _(void, QInt8) /* 12 */ \
  _(void, QUInt8) /* 13 */ \
  _(void, QInt32) /* 14 */ \
  _(at::BFloat16, BFloat16) /* 15 */ \

#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_QINT(_) \
  _(uint8_t, Byte) \
  _(int8_t, Char) \
  _(int16_t, Short) \
  _(int, Int) \
  _(int64_t, Long) \
  _(at::Half, Half) \
  _(float, Float) \
  _(double, Double) \
  _(std::complex<at::Half>, ComplexHalf) \
  _(std::complex<float>, ComplexFloat) \
  _(std::complex<double>, ComplexDouble) \
  _(bool, Bool) \
  _(at::BFloat16, BFloat16)


enum class ScalarType : int8_t {
#define DEFINE_ENUM(_1, n) n,
  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ENUM)
#undef DEFINE_ENUM
  Undefined,
  NumOptions
};

template <typename T, int size>
struct Array {
  T data[size];

  __device__ T operator[](int i) const {
    return data[i];
  }
  __device__ T& operator[](int i) {
    return data[i];
  }
  Array() = default;
  Array(const Array&) = default;
  Array& operator=(const Array&) = default;
  __device__ Array(T x) {
    for (int i = 0; i < size; i++) {
      data[i] = x;
    }
  }
};

${half_string}
${bfloat16_string}
${complex_body_string}
${complex_half_body_string}
${complex_math_string}


)ESCAPE";

//we need to include half, bfloat16 and complex strings to all kernels with half arguments and to all kernels with type casting
//regardless of whether they have half arguments (because fetch_and_cast and cast_and_store loop over all types)
const std::string jiterator_half_support_literal = R"ESCAPE(
namespace at {
struct alignas(2) Half {
  unsigned short x;

  Half() = default;
  inline __host__ __device__ Half(float value){
#ifdef __HIPCC__
    x = __half_as_short(__float2half(value));
#else
    asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(x) : "f"(value));
#endif
  }
  inline __host__ __device__ operator float() const{
#ifdef __HIPCC__
    return __half2float(*reinterpret_cast<const __half*>(&x));
#else
    float val;
    asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(x)); // do we need const cast here?
    //asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__HALF_TO_CUS(x)));
    return val;
#endif
  }
};
}
)ESCAPE";

const std::string jiterator_bfloat16_support_literal = R"ESCAPE(
namespace at {
struct alignas(2) BFloat16 {
  unsigned short x;

  __device__ unsigned short __internal_float2bfloat16(
      const float f,
      unsigned int& sign,
      unsigned int& remainder) {
    unsigned int x;

    x = __float_as_uint(f);

    if ((x & 0x7fffffffU) > 0x7f800000U) {
      sign = 0U;
      remainder = 0U;
      return static_cast<unsigned short>(0x7fffU);
    }
    sign = x >> 31;
    remainder = x << 16;
    return static_cast<unsigned short>(x >> 16);
  }


  BFloat16() = default;
  inline __host__ __device__ BFloat16(float value){
#if __CUDA_ARCH__ >= 800
    asm("{ cvt.rn.bf16.f32 %0, %1;}\n" : "=h"(x) : "f"(value));
)ESCAPE"
R"ESCAPE(
#else
    unsigned int sign;
    unsigned int remainder;
    x = __internal_float2bfloat16(value, sign, remainder);
    if ((remainder > 0x80000000U) ||
        ((remainder == 0x80000000U) && ((x & 0x1U) != 0U))) {
      x++;
    }
#endif
  }

  inline __host__ __device__ operator float() const{
#ifdef __HIPCC__
    union
    {
      uint32_t int32;
      float fp32;
    } u = {uint32_t(x) << 16};
    return u.fp32;
#else
    float val;
    asm("{ mov.b32 %0, {0,%1};}\n" : "=f"(val) : "h"(x)); //do we need const cast here?
    return val;
#endif
  }

};
}
)ESCAPE";

// From c10/util/Load.h
const std::string load_support_literal = R"ESCAPE(

namespace c10 {
  template <typename T>
  struct LoadImpl {
    __device__ static T apply(const void *src) {
      return *reinterpret_cast<const T*>(src);
    }
  };

  template <>
  struct LoadImpl<bool> {
    __device__ static bool apply(const void *src) {
      static_assert(sizeof(bool) == sizeof(char), "");
      return LoadImpl<char>::apply(src);
    }
  };

  template <typename T>
  __device__ T load(const void *src) {
    return LoadImpl<T>::apply(src);
  }

  template <typename scalar_t>
  __device__ scalar_t load(const scalar_t *src) {
    return LoadImpl<scalar_t>::apply(src);
  }
} // namespace c10

)ESCAPE";

// copy-pasted from c10/util/TypeCast.h and c10/core/DynamicCast.h
const std::string dynamic_cast_support_literal = R"ESCAPE(

template <typename T>
struct is_complex : public std::false_type {};

template <typename T>
struct is_complex<std::complex<T>> : public std::true_type {};

template <typename dest_t, typename src_t>
struct needs_real {
  constexpr static bool value =
      (is_complex<src_t>::value && !is_complex<dest_t>::value);
};

template <bool, typename src_t>
struct maybe_real {
  static inline src_t apply(src_t src) {
    return src;
  }
};

template <typename src_t>
struct maybe_real<true, src_t> {
  static inline decltype(auto) apply(src_t src) {
    return src.real();
  }
};

template <typename dest_t, typename src_t>
struct static_cast_with_inter_type {
  static inline dest_t apply(
      src_t src) {
    constexpr bool real = needs_real<dest_t, src_t>::value;
    return static_cast<dest_t>(maybe_real<real, src_t>::apply(src));
  }
};

template <typename src_t>
struct static_cast_with_inter_type<uint8_t, src_t> {
  static inline uint8_t apply(
      src_t src) {
    constexpr bool real = needs_real<uint8_t, src_t>::value;
    return static_cast<uint8_t>(
        static_cast<int64_t>(maybe_real<real, src_t>::apply(src)));
  }
};

template <>
struct static_cast_with_inter_type<std::complex<at::Half>, at::BFloat16> {
  static inline std::complex<at::Half> apply(at::BFloat16 src) {
    return static_cast<std::complex<at::Half>>(float{src});
  }
};

template <>
struct static_cast_with_inter_type<std::complex<at::Half>, at::Half> {
  static inline std::complex<at::Half> apply(at::Half src) {
    return static_cast<std::complex<at::Half>>(float{src});
  }
};

template <>
struct static_cast_with_inter_type<
    std::complex<at::Half>,
    std::complex<double>> {
  static inline std::complex<at::Half> apply(std::complex<double> src) {
    return static_cast<std::complex<at::Half>>(static_cast<std::complex<float>>(src));
  }
};

// Fetch a value with dynamic type src_type from ptr, and cast it to static type dest_t.
#define FETCH_AND_CAST_CASE(type, scalartype) \
  case ScalarType::scalartype: \
    return static_cast_with_inter_type<dest_t, type>::apply(c10::load<type>(ptr));
template<typename dest_t>
__device__ inline dest_t fetch_and_cast(const ScalarType src_type, const void *ptr) {
  switch (src_type) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_QINT(FETCH_AND_CAST_CASE)
    default:
      ERROR_UNSUPPORTED_CAST
  }
  return dest_t(0); // just to avoid compiler warning
}

// Cast a value with static type src_t into dynamic dest_type, and store it to ptr.
#define CAST_AND_STORE_CASE(type, scalartype) \
  case ScalarType::scalartype: \
    *(type*)ptr = static_cast_with_inter_type<type, src_t>::apply(value); \
    return;
template<typename src_t>
__device__ inline void cast_and_store(const ScalarType dest_type, void *ptr, src_t value) {
  switch (dest_type) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_QINT(CAST_AND_STORE_CASE)
    default:;
  }
  ERROR_UNSUPPORTED_CAST
}

template <int N>
struct LoadWithCast {
  using array_t = Array<ScalarType, N==0? 1 : N>;
  using size_array_t = Array<uint32_t, N==0? 1: N>;

  array_t dtypes;
  size_array_t element_sizes;
  template <typename scalar_t>
  __device__ scalar_t load(char* base_ptr, uint32_t offset, int arg) {
    void* ptr = base_ptr + element_sizes[arg] * offset;
    return fetch_and_cast<scalar_t>(dtypes[arg], ptr);
  }
};

template <int N = 1>
struct StoreWithCast {
  using array_t = Array<ScalarType, N==0? 1 : N>;
  using size_array_t = Array<uint32_t, N==0? 1: N>;

  array_t dtypes;
  size_array_t element_sizes;

  template<typename scalar_t>
  __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
    void *ptr = base_ptr + element_sizes[arg] * offset;
    cast_and_store<scalar_t>(dtypes[arg], ptr, value);
  }
};

)ESCAPE";

const std::string no_dynamic_cast_support_literal = R"ESCAPE(

struct LoadWithoutCast {
  template <typename scalar_t>
  __device__ scalar_t load(char* base_ptr, uint32_t offset, int arg=0) {
    return c10::load(reinterpret_cast<scalar_t*>(base_ptr) + offset);
  }
};

struct StoreWithoutCast {
  template<typename scalar_t>
  __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg=0) {
    *(reinterpret_cast<scalar_t *>(base_ptr) + offset) = value;
  }
};

)ESCAPE";

const std::string offset_calc_template = R"ESCAPE(
template <typename T>
struct DivMod {
  T div;
  T mod;

  __device__ DivMod(T _div, T _mod) {
    div = _div;
    mod = _mod;
  }
};

//<unsigned int>
struct IntDivider {
  IntDivider() = default;

  __device__ inline unsigned int div(unsigned int n) const {
    unsigned int t = __umulhi(n, m1);
    return (t + n) >> shift;
  }

  __device__ inline unsigned int mod(unsigned int n) const {
    return n - div(n) * divisor;
  }

  __device__ inline DivMod<unsigned int> divmod(unsigned int n) const {
    unsigned int q = div(n);
    return DivMod<unsigned int>(q, n - q * divisor);
  }

  unsigned int divisor; // d above.
  unsigned int m1; // Magic number: m' above.
  unsigned int shift; // Shift amount.
};
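
// NOTE: IntDivider above divides by `divisor` without a hardware divide: the
// host side precomputes the magic multiplier m1 and the shift so that
// div(n) == n / divisor for the index range used here (see the IntDivider
// setup reached via ATen/cuda/detail/OffsetCalculator.cuh).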

template <int NARGS>
struct TrivialOffsetCalculator {
  // The offset for each argument. Wrapper around fixed-size array.
  // The offsets are in # of elements, not in bytes.
  Array<${index_type}, NARGS> get(${index_type} linear_idx) const {
    Array<${index_type}, NARGS> offsets;
    #pragma unroll
    for (int arg = 0; arg < NARGS; arg++) {
      offsets[arg] = linear_idx;
    }
    return offsets;
  }
};

template<int NARGS>
struct OffsetCalculator {
  OffsetCalculator() = default;
  __device__ __forceinline__ Array<${index_type}, NARGS> get(${index_type} linear_idx) const {
    Array<${index_type}, NARGS> offsets;
    #pragma unroll
    for (int arg = 0; arg < NARGS; ++arg) {
      offsets[arg] = 0;
    }

    #pragma unroll
    for (int dim = 0; dim < MAX_DIMS; ++dim) {
      if (dim == dims) {
        break;
      }

      auto divmod = sizes_[dim].divmod(linear_idx);
      linear_idx = divmod.div;

      #pragma unroll
      for (int arg = 0; arg < NARGS; ++arg) {
        offsets[arg] += divmod.mod * strides_[dim][arg];
      }
      //printf("offset calc thread dim size stride offset %d %d %d %d %d %d %d %d\n",
      //threadIdx.x, dim, sizes_[dim].divisor, strides_[dim][0], offsets[0], linear_idx, divmod.div, divmod.mod);
    }
    return offsets;
  }

  int dims;
  IntDivider sizes_[MAX_DIMS];
  // NOTE: this approach will not support nInputs == 0
  ${index_type} strides_[MAX_DIMS][NARGS];
};


)ESCAPE";
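// The OffsetCalculator emitted from the template above maps a linear element
// index to per-argument element offsets: for each dimension it splits the
// index with divmod against that dimension's size and accumulates
// remainder * stride for every argument, so strided (non-contiguous) tensors
// can be addressed from a flat loop index.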

const std::string jit_code_template = R"ESCAPE(

  ${load_support}
  ${dynamic_casting_string}


  ${functor}

  // TODO: setup grid-stride loop
  extern "C" __global__
  void ${name}_kernel(
      const int numel,
      Array<char*, ${nInputs}+${nOutputs}> data, //[${nInputs}+${nOutputs}],
      ${offset_calculator}<${nInputs}> input_calculator,
      ${offset_calculator}<${nOutputs}> output_calculator,
      ${loader} l,
      ${storer} s,
      ${compute_type} scalar_val${extra_params}) {
    ${declare_load_arrays}
    ${declare_store_arrays}

    int idx = blockIdx.x;

    int remaining = numel - block_work_size * idx;
    int thread_idx = threadIdx.x;

    #pragma unroll
    for (int j = 0; j < thread_work_size; j++){
      if (thread_idx >= remaining) {
        break;
      }

      int linear_idx = thread_idx + block_work_size * idx;
      auto input_offsets = input_calculator.get(linear_idx);
      ${load_inputs}
      // printf(
      //   "thread %d a %f offsets %d\n", threadIdx.x, arg0[j], input_offsets[0]);
      thread_idx += num_threads;
    }

    #pragma unroll
    for (int j = 0; j < thread_work_size; j++) {
      if ((threadIdx.x + j*num_threads) < remaining) {
        ${call_functor}
      }
    }

    thread_idx = threadIdx.x;
    #pragma unroll
    for (int j = 0; j < thread_work_size; j++){
      if (thread_idx >= remaining) {
        break;
      }
      //TODO maybe think about unifying offset calculators and reuse
      //offsets computed in the load loop
      int linear_idx = thread_idx + block_work_size * idx;
      auto output_offsets = output_calculator.get(linear_idx);
      //printf("output thread %d offset %d\n", threadIdx.x, output_offsets[0]);
      ${store_outputs}
      thread_idx += num_threads;
    }
  }
)ESCAPE";
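// Sketch of the elementwise kernel generated from the template above: each
// block owns block_work_size elements and runs three passes - gather inputs
// into per-thread register arrays, apply ${functor}, then scatter outputs -
// with threads striding by num_threads and the `remaining` check guarding the
// tail block.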

const std::string jit_vectorized_code_template = R"ESCAPE(

  ${load_support}

  template <typename scalar_t>
  __device__ __inline__ scalar_t load(char* base_ptr, uint32_t offset) {
    return c10::load(reinterpret_cast<scalar_t*>(base_ptr) + offset);
  }

  template<typename scalar_t>
  __device__ __inline__ void store(scalar_t value, char *base_ptr, uint32_t offset) {
    *(reinterpret_cast<scalar_t *>(base_ptr) + offset) = value;
  }

  // aligned vector generates vectorized load/store on CUDA
  template<typename scalar_t, int vec_size>
  struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
    scalar_t val[vec_size];
  };

  template <int vec_size, typename scalar_t>
  __device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) {
    using vec_t = aligned_vector<scalar_t, vec_size>;
    auto *from = reinterpret_cast<const vec_t *>(base_ptr);
    return from[offset];
  }

  template <int vec_size>
  __device__ aligned_vector<bool, vec_size> load_vector(const bool *base_ptr, uint32_t offset) {
    // See NOTE [Loading boolean values]
    auto tmp = load_vector<vec_size>(reinterpret_cast<const uint8_t*>(base_ptr), offset);
    aligned_vector<bool, vec_size> ret;
    for (int i = 0; i < vec_size; ++i) {
      ret.val[i] = bool(tmp.val[i]);
    }
    return ret;
  }

  ${functor}

  // TODO: setup grid-stride loop

  extern "C" __global__
  void ${name}_vectorized${vec_size}_kernel(
      const int N,
      Array<char*, ${nInputs}+${nOutputs}> data,
      ${compute_type} scalar_val${extra_params}) //[${nInputs}+${nOutputs}],
  {
    constexpr int vec_size = ${vec_size};
    using scalar_t = ${scalar_type};
    int remaining = N - block_work_size * blockIdx.x;
    int thread_idx = threadIdx.x;
    int idx = blockIdx.x;
    ${declare_load_arrays}
    ${declare_store_arrays}

    if (remaining < block_work_size) {
      #pragma unroll
      for (int j = 0; j < thread_work_size; j++){
        if (thread_idx >= remaining) {
          break;
        }
        int linear_idx = thread_idx + block_work_size * idx;
        ${load_unrolled_inputs}
        thread_idx += num_threads;
      }
      #pragma unroll
      for (int j = 0; j < thread_work_size; j++) {
        if ((threadIdx.x + j*num_threads) < remaining) {
          ${call_functor}
        }
      }
      thread_idx = threadIdx.x;
      #pragma unroll
      for (int j = 0; j < thread_work_size; j++) {
        if (thread_idx >= remaining) {
          break;
        }
        int linear_idx = thread_idx + block_work_size * idx;
        ${store_unrolled_outputs}
        thread_idx += num_threads;
      }
    } else {
      static constexpr int loop_size = thread_work_size / vec_size;
      //actual loading
      ${vector_inputs}
      #pragma unroll
      for (int i = 0; i<loop_size; i++){
        ${load_vectorized_inputs}
        thread_idx += num_threads;
      }

      #pragma unroll
      for (int j = 0; j < thread_work_size; j++) {
        ${call_functor}
      }

      using vec_t_output = aligned_vector<${result_type}, vec_size>;
      ${vector_outputs}
      int thread_idx = threadIdx.x;
      #pragma unroll
      for (int i = 0; i<loop_size; i++){
        vec_t_output v;
        ${store_vectorized_outputs}
        thread_idx += num_threads;
      }
    }
  }
)ESCAPE";
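// The vectorized variant above keeps two paths: fully occupied blocks take the
// `else` branch and move aligned_vector<scalar_t, vec_size> packets, while the
// final, partially filled block falls back to the scalar unrolled path with
// per-element bounds checks.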

static void replace_all(std::string& s, const std::string& to_replace, const std::string& replace_with) {
  std::ostringstream oss;
  std::size_t pos = 0;
  std::size_t prev_pos = pos;

  while (true) {
    prev_pos = pos;
    pos = s.find(to_replace, pos);
    if (pos == std::string::npos)
      break;
    oss << s.substr(prev_pos, pos - prev_pos);
    oss << replace_with;
    pos += to_replace.size();
  }

  oss << s.substr(prev_pos);
  s = oss.str();
}

// hipify replaces certain device math functions, e.g., std::max -> ::max
// See torch/utils/hipify/cuda_to_hip_mappings.py.
// Replace them back. Search for " ::<name>" to avoid duplicate replacements.
static std::string unhipify_math_functions(const std::string &original) {
  static std::vector<std::pair<std::string,std::string>> mappings = {
    {" std::max", " ::max"},
    {" std::min", " ::min"},
    {" std::ceil", " ::ceil"},
    {" std::floor", " ::floor"},
    {" std::exp", " ::exp"},
    {" std::log", " ::log"},
    {" std::pow", " ::pow"},
    {" std::fabs", " ::fabs"},
    {" std::fmod", " ::fmod"},
    {" std::remainder", " ::remainder"},
    {" std::frexp", " ::frexp"}
  };
  std::string ret = original;
  for (const auto& mapping : mappings) {
    replace_all(ret, mapping.second, mapping.first);
  }
  return ret;
}

// The following is copied from fused_kernel.cpp
// TODO: refactor codegenOutputQuery into its own file
// that can be included by both files
// See NOTE [ USE OF NVRTC AND DRIVER API ]
const at::cuda::NVRTC& nvrtc() {
  return at::globalContext().getNVRTC();
}

// query codegen output arch and target
// TODO refactor so this function is usable both from jit and from aten
void codegenOutputQuery(
    const cudaDeviceProp* const prop,
    int& cuda_major,
    int& cuda_minor,
    int& nvrtc_major,
    int& nvrtc_minor,
    bool& compile_to_sass) {
#ifdef USE_ROCM
  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcVersion(&nvrtc_major, &nvrtc_minor));
  cuda_major = prop->major;
  cuda_minor = prop->minor;
  compile_to_sass = false;
#else
  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcVersion(&nvrtc_major, &nvrtc_minor));
  TORCH_CHECK(
      nvrtc_major >= 6, "NVRTC versions less than 6 are not supported. Is: ", nvrtc_major);

  // Version supported by device
  // Usually any lower version works too but is less efficient
  using CUDAVersion = std::pair<int, int>;
  const CUDAVersion nvrtc_version{nvrtc_major, nvrtc_minor};
  const CUDAVersion dev_version{prop->major, prop->minor};
  // Maximum version supported by the driver, cap dev_version to this
  CUDAVersion max_dev_version;
  if (nvrtc_major <= 7) { // 7 supports 2-5.x
    max_dev_version = CUDAVersion(5, 0);
  } else if (nvrtc_major <= 8) { // 8 supports 2-6.x
    max_dev_version = CUDAVersion(6, 0);
  } else if (nvrtc_major <= 9) { // 9 supports 3-7.2
    max_dev_version = CUDAVersion(7, 2);
  } else if (nvrtc_major <= 10) { // 10 supports 3-7.5
    max_dev_version = CUDAVersion(7, 5);
  } else if (nvrtc_version == CUDAVersion(11, 0)) { // 11.0 supports 3-8.0
    max_dev_version = CUDAVersion(8, 0);
  } else if (nvrtc_major == 11 && nvrtc_minor < 8) {
    max_dev_version = CUDAVersion(8, 6);
  } else {
    // If the driver version is unknown (i.e. newer than this code)
    // assume the driver supports this device
    max_dev_version = dev_version;
  }

  if (dev_version > max_dev_version) {
    cuda_major = max_dev_version.first;
    cuda_minor = max_dev_version.second;
    // if we are clamping major/minor, sass is not compatible
    compile_to_sass = false;
  } else {
    cuda_major = dev_version.first;
    cuda_minor = dev_version.second;
    compile_to_sass = true;
  }

#if defined(CUDA_VERSION) && CUDA_VERSION < 11010
  // compile to sass is not allowed prior to CUDA 11.1
  compile_to_sass = false;
#endif
#endif
}
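// Illustrative example of the clamping above: an sm_90 device paired with
// NVRTC 11.4 (which reports support up to 8.6 here) is clamped to 8.6 and
// compile_to_sass is set to false, so PTX is emitted instead of SASS.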

// TODO: another copy paste from jit, refactor so it's usable from both
// TODO: try making the CUcontext thread local to see if that improves performance - why is this slow?
void initializeCudaContext() {
  // lazily construct context if non-existing yet;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  CUcontext pctx = nullptr;
  AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx));
  if (!pctx) {
    std::unique_lock<std::mutex> cudaFreeMutexLock(
        *(c10::cuda::getFreeMutex()));
    // cudaFree(nullptr) is a no-op whose side effect is to initialize the
    // primary context on the current device.
    cudaFree(nullptr);
  }
}

std::string generate_code(
    const KernelDescriptor &desc,
    bool contiguous,
    bool dynamic_casting,
    BinaryFuncVariant scalar_pos,
    bool vectorized,
    int vec_size,
    bool return_by_ref) {
  c10::SmallVector<std::string> extra_args_typenames(desc.extra_args_types.size());
  for (auto i : c10::irange(extra_args_typenames.size())) {
    extra_args_typenames[i] = typeName(desc.extra_args_types[i]);
  }

  return generate_code(
      desc.nInputs,
      desc.nOutputs,
      desc.f,
      desc.name,
      typeName(desc.f_inputs_type),
      typeName(toOpMathType(desc.f_inputs_type)),
      typeName(desc.result_type),
      contiguous,
      dynamic_casting,
      scalar_pos,
      extra_args_typenames,
      vectorized,
      vec_size,
      return_by_ref);
}

// FIXME - these are defined in Loops.cuh, but including Loops.cuh here would lead to circular includes Loops.cuh -> CUDALoops.cuh -> jit_utils.h -> Loops.cuh
#define THREAD_WORK_SIZE 4
constexpr int thread_work_size = THREAD_WORK_SIZE;

std::string generate_code(
    int nInputs,
    int nOutputs,
    const std::string& func_,
    const std::string& name,
    const std::string& f_inputs_type,
    const std::string& compute_type,
    const std::string& result_type,
    bool contiguous,
    bool dynamic_casting,
    BinaryFuncVariant scalar_pos,
    c10::SmallVector<std::string>& extra_args_typenames,
    bool vectorized,
    int vec_size,
    bool return_by_ref) {
  std::string func = func_;
  at::jit::TemplateEnv env;

  env.s("index_type", "unsigned int");
  env.s("nInputs", std::to_string(nInputs));
  env.s("nOutputs", std::to_string(nOutputs));
  env.s("scalar_type", f_inputs_type);
  env.s("compute_type", compute_type);
  env.s("functor", func);
  env.s("name", name);
  env.s("cmath_string", get_cmath_string());

  // Generate `extra_params` for function signature
  // and `extra_args` for computation call if
  // extra arguments to capture runtime state are passed.
  // (look at polygamma for example).
  std::string extra_params = "";
  std::string extra_args = "";
  for (size_t i = 0; i < extra_args_typenames.size(); i++) {
    auto type = std::string(extra_args_typenames[i]);
    auto name = "extra_arg_" + std::to_string(i);
    extra_params += "," + type + " " + name;
    extra_args += ", " + name;
  }
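  // e.g. (illustrative) a single extra argument of type int, as polygamma
  // passes, yields extra_params == ",int extra_arg_0" and
  // extra_args == ", extra_arg_0".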
1009 env.s("extra_params", extra_params);
1010 env.s("extra_args", extra_args);
1011
1012 std::stringstream declare_load_arrays;
1013 for (int i = 0; i < nInputs; i++) {
1014 // TODO these arrays are potentially of the different types, use function
1015 // traits to determine the types
1016 declare_load_arrays << f_inputs_type << " arg" << std::to_string(i)
1017 << "[" << std::to_string(thread_work_size) << "];\n";
1018 }
1019 env.s("declare_load_arrays", declare_load_arrays.str());
1020
1021 std::stringstream declare_store_arrays;
1022 for (int i = 0; i < nOutputs; i++) {
1023 declare_store_arrays << result_type << " out" << std::to_string(i)
1024 << "[" << std::to_string(thread_work_size) << "];\n";
1025 }
1026 env.s("declare_store_arrays", declare_store_arrays.str());
1027
1028 std::stringstream functor_args;
1029 if (scalar_pos == BinaryFuncVariant::NoScalar) {
1030 for (int i = 0; i < nInputs - 1; i++) {
1031 functor_args << "arg" << std::to_string(i) << "[j], ";
1032 }
1033 functor_args << "arg" << std::to_string(nInputs - 1) << "[j]";
1034 } else if (scalar_pos == BinaryFuncVariant::LhsScalar) {
1035 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(nInputs == 1);
1036 functor_args << "scalar_val, arg0[j]";
1037 } else { //RhsScalar
1038 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(nInputs == 1);
1039 functor_args << "arg0[j], scalar_val";
1040 }
1041 env.s("args", functor_args.str());
1042
1043 std::string call_functor_template;
1044 if (return_by_ref) { // return one or more outputs by reference
1045 bool need_temp_out = (compute_type != result_type);
1046 std::stringstream functor_outs;
1047 if (need_temp_out) {
1048 for (int i = 0; i < nOutputs - 1; i++) {
1049 functor_outs << "temp_out" << std::to_string(i) << ", ";
1050 }
1051 functor_outs << "temp_out" << std::to_string(nOutputs - 1);
1052 } else {
1053 for (int i = 0; i < nOutputs - 1; i++) {
1054 functor_outs << "out" << std::to_string(i) << "[j], ";
1055 }
1056 functor_outs << "out" << std::to_string(nOutputs - 1) << "[j]";
1057 }
1058 env.s("functor_outs", functor_outs.str());
1059
1060 if (need_temp_out) {
1061 call_functor_template += "${compute_type} ${functor_outs};\n";
1062 }
1063
1064 call_functor_template += "${name}<${compute_type}>(${args} ${extra_args}, ${functor_outs});\n";
1065
1066 if (need_temp_out) {
1067 for (int i = 0; i < nOutputs; i++) {
1068 auto i_string = std::to_string(i);
1069 call_functor_template += "out" +i_string + "[j] = temp_out" + i_string + ";\n";
1070 }
1071 }
1072
1073 } else { // return by value for single output functor
1074 call_functor_template = "out0[j] = ${name}<${compute_type}>(${args} ${extra_args});";
1075 }
1076 env.s("call_functor", at::jit::CodeTemplate(call_functor_template).format(env));
1077
1078 if (f_inputs_type == "at::Half" || result_type == "at::Half" ||
1079 f_inputs_type == "std::complex<at::Half>" ||
1080 result_type == "std::complex<at::Half>" || dynamic_casting) {
1081 // complex<Half> depends on complex<T> and Half dtypes.
1082 env.s("half_string", jiterator_half_support_literal);
1083 } else {
1084 env.s("half_string", "");
1085 }
1086 if (f_inputs_type == "at::BFloat16" || result_type == "at::BFloat16" || dynamic_casting) {
1087 env.s("bfloat16_string", jiterator_bfloat16_support_literal);
1088 } else {
1089 env.s("bfloat16_string", "");
1090 }
1091 // the definition of complex math functions is only needed when the compute type is complex
1092 // but the definition of std::complex is needed for dynamic casting even if the compute type is not complex
1093 if (f_inputs_type == "std::complex<float>" || result_type == "std::complex<float>" ||
1094 f_inputs_type == "std::complex<double>" || result_type == "std::complex<double>" ||
1095 f_inputs_type == "std::complex<at::Half>" || result_type == "std::complex<at::Half>") {
1096 // complex<Half> depends on complex<T> and Half dtypes.
1097 env.s("traits_string", get_traits_string_but_hiprtc_safe());
1098 env.s("complex_body_string", get_complex_body_string());
1099 env.s("complex_math_string", get_complex_math_string());
1100 #ifdef USE_ROCM
1101 // unhipify math functions, but only if std::complex is used.
1102 func = unhipify_math_functions(func);
1103 env.s("functor", func);
1104 #endif
1105 } else if (dynamic_casting) {
1106 env.s("traits_string", get_traits_string_but_hiprtc_safe());
1107 env.s("complex_body_string", get_complex_body_string());
1108 env.s("complex_math_string", "");
1109 } else {
1110 env.s("traits_string", "");
1111 env.s("complex_body_string", "");
1112 env.s("complex_math_string", "");
1113 }
1114 if (f_inputs_type == "std::complex<at::Half>" ||
1115 result_type == "std::complex<at::Half>" || dynamic_casting) {
1116 // dynamic_casting requires the definition of all types
1117 // include complex<at::Half>
1118 // Look at the definition of `StoreWithCast` and `LoadWithCast`.
1119 env.s("complex_half_body_string", get_complex_half_body_string());
1120 } else {
1121 env.s("complex_half_body_string", "");
1122 }
1123
1124 env.s("load_support", load_support_literal);
1125
1126 if (!vectorized) {
1127 if (!dynamic_casting) {
1128 env.s("loader", "LoadWithoutCast");
1129 env.s("storer", "StoreWithoutCast");
1130 env.s("dynamic_casting_string", no_dynamic_cast_support_literal);
1131 } else {
1132 env.s("loader", std::string("LoadWithCast<" + std::to_string(nInputs) + ">"));
1133 env.s("storer", std::string("StoreWithCast<" + std::to_string(nOutputs) + ">"));
1134 env.s("dynamic_casting_string", dynamic_cast_support_literal);
1135 }
1136
1137 if (contiguous) {
1138 env.s("offset_calculator", "TrivialOffsetCalculator");
1139 } else {
1140 env.s("offset_calculator", "OffsetCalculator");
1141 }
1142
1143 std::stringstream load_inputs;
1144 for (int i = 0; i < nInputs; i++) {
1145 auto i_string = std::to_string(i);
1146 load_inputs << "arg" << i_string << "[j] = l.load<" << f_inputs_type
1147 << ">(data[" << std::to_string(i + nOutputs)
1148 << "], input_offsets[" << i_string << "], " << i_string
1149 << ");\n";
1150 }
1151 env.s("load_inputs", load_inputs.str());
1152
1153 std::stringstream store_outputs;
1154 for (int i = 0; i < nOutputs; i++) {
1155 auto i_string = std::to_string(i);
1156 store_outputs << "s.store<" << result_type
1157 << ">(out" << i_string << "[j], data[" << i_string
1158 << "], output_offsets[" << i_string << "], " << i_string
1159 << ");\n";
1160 }
1161 env.s("store_outputs", store_outputs.str());
1162
1163 static auto cuda_template = at::jit::CodeTemplate(
1164 jit_preamble + jit_common_types + offset_calc_template + jit_code_template + jit_epilogue);
1165 const auto code = cuda_template.format(env);
1166 return code;
1167 }
1168
1169 // vectorized case
1170 env.s("vec_size", std::to_string(vec_size));
1171 env.s("result_type", result_type);
1172
1173 std::stringstream vector_inputs;
1174 for (const auto i : c10::irange(nInputs)){
1175 auto i_string = std::to_string(i);
1176 vector_inputs << "auto * input" << i_string <<
1177 " = reinterpret_cast<const scalar_t*>(data[" << i_string << "+" << nOutputs << "])" <<
1178 " + block_work_size * idx;\n";
1179 }
1180 env.s("vector_inputs", vector_inputs.str());
1181
1182 std::stringstream vector_outputs;
1183 for (const auto i : c10::irange(nOutputs)){
1184 auto i_string = std::to_string(i);
1185 vector_outputs << "vec_t_output* to_" << i_string <<
1186 " = reinterpret_cast<vec_t_output*>(data[" << i_string << "])" <<
1187 " + block_work_size / vec_size * idx;\n";
1188 }
1189 env.s("vector_outputs", vector_outputs.str());
1190
1191 std::stringstream load_vectorized_inputs;
1192 for (const auto i : c10::irange(nInputs)) {
1193 auto i_string = std::to_string(i);
1194 load_vectorized_inputs << "const auto vec" << i_string << " = load_vector<vec_size>("
1195 << "input" << i_string << ", thread_idx);\n";
1196 load_vectorized_inputs << "#pragma unroll\n";
1197 load_vectorized_inputs << "for (int j=0; j < vec_size; j++){\n";
1198 load_vectorized_inputs << " arg" << i_string << "[vec_size * i + j] = vec" << i_string << ".val[j];\n";
1199 load_vectorized_inputs << "}\n";
1200 }
1201 env.s("load_vectorized_inputs", load_vectorized_inputs.str());
1202
1203 std::stringstream store_vectorized_outputs;
1204 for (const auto i : c10::irange(nOutputs)) {
1205 auto i_string = std::to_string(i);
1206 store_vectorized_outputs << "#pragma unroll\n";
1207 store_vectorized_outputs << "for (int j=0; j<vec_size; j++){\n";
1208 store_vectorized_outputs << "v.val[j] = out" << i_string << "[vec_size * i + j];\n";
1209 store_vectorized_outputs << "}\n";
1210 store_vectorized_outputs << "to_"<< i_string << "[thread_idx] = v;\n";
1211 }
1212 env.s("store_vectorized_outputs", store_vectorized_outputs.str());
1213
1214 std::stringstream load_unrolled_inputs;
1215 for (const auto i: c10::irange(nInputs)){
1216 auto i_string = std::to_string(i);
1217 load_unrolled_inputs << "arg" << i_string << "[j] = load<" << f_inputs_type
1218 << ">(data[" << std::to_string(i + nOutputs) << "], linear_idx);\n";
1219 }
1220 env.s("load_unrolled_inputs", load_unrolled_inputs.str());
1221
1222 std::stringstream store_unrolled_outputs;
1223 for (const auto i : c10::irange(nOutputs)) {
1224 auto i_string = std::to_string(i);
1225 store_unrolled_outputs << "store<" << result_type << ">(out" << i_string
1226 << "[j], data[" << i_string << "], linear_idx);\n";
1227 }
1228 env.s("store_unrolled_outputs", store_unrolled_outputs.str());
1229
1230 static auto cuda_template = at::jit::CodeTemplate(
1231 jit_preamble + jit_common_types + jit_vectorized_code_template + jit_epilogue);
1232 const auto code = cuda_template.format(env);
1233 return code;
1234 }
1235
1236 // Creates directories recursively
_r_mkdir(const std::string & dir)1237 bool _r_mkdir(const std::string& dir) {
1238 // Check if current dir exists
1239 const char* p_dir = dir.c_str();
1240 const bool dir_exists = (access(p_dir, F_OK) == 0);
1241 if (dir_exists) {
1242 return true;
1243 }
1244
1245 // Try to create current directory
1246 #ifdef _WIN32
1247 int ret = _mkdir(dir.c_str());
1248 #else
1249 int ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO);
1250 #endif
1251 // Success
1252 if (ret == 0) {
1253 return true;
1254 }
1255
1256 // Find folder separator and check if we are at the top
1257 auto pos = dir.find_last_of("/\\");
1258 if (pos == std::string::npos) {
1259 return false;
1260 }
1261
1262 // Try to create parent directory
1263 if (!(_r_mkdir(dir.substr(0, pos)))) {
1264 return false;
1265 }
1266
1267 // Try to create complete path again
1268 #ifdef _WIN32
1269 ret = _mkdir(dir.c_str());
1270 #else
1271 ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO);
1272 #endif
1273 return ret == 0;
1274 }
1275
1276 // Creates directories recursively assuming that base exists
r_mkdir_with_base(std::string & base,std::string & dir)1277 bool r_mkdir_with_base(std::string& base, std::string& dir){
1278 const char* p_base = base.c_str();
1279 const bool base_exists = (access(p_base, F_OK) == 0);
1280 if (!base_exists) {
1281 return false;
1282 }
1283
1284 // remove trailing '/' or '\\'
1285 if ((base[base.size()-1]=='/') || base[base.size()-1]=='\\') {
1286 base.pop_back();
1287 }
1288 if ((dir[dir.size()-1]=='/') || dir[dir.size()-1]=='\\') {
1289 dir.pop_back();
1290 }
1291
1292 return _r_mkdir(base+dir);
1293
1294 }
1295
load_code_template(const std::string & path)1296 std::string load_code_template(const std::string& path) {
1297 std::ifstream ifs{path};
1298 std::string s{
1299 std::istreambuf_iterator<char>(ifs),
1300 std::istreambuf_iterator<char>()};
1301 return s;
1302 }
1303
generate_reduction_code(const KernelDescriptor & desc,int vt0,bool contiguous,bool vectorized,int vec_size,int max_threads_codegen)1304 std::string generate_reduction_code(
1305 const KernelDescriptor &desc,
1306 int vt0,
1307 bool contiguous,
1308 bool vectorized,
1309 int vec_size,
1310 int max_threads_codegen) {
1311 TORCH_INTERNAL_ASSERT(desc.nInputs == 1);
1312 TORCH_INTERNAL_ASSERT(desc.extra_args_types.size() == 0);
1313
1314 return generate_reduction_code(
1315 desc.nOutputs,
1316 desc.f,
1317 desc.name,
1318 vt0,
1319 typeName(desc.f_inputs_type),
1320 typeName(toOpMathType(desc.f_inputs_type)),
1321 typeName(desc.result_type),
1322 contiguous,
1323 vectorized,
1324 vec_size,
1325 max_threads_codegen
1326 );
1327 }
1328
generate_reduction_code(int nOutputs,const std::string & func_,const std::string & name,const int vt0,const std::string & f_inputs_type,const std::string & reduction_accum_type,const std::string & result_type,bool contiguous,bool vectorized,int vec_size,int max_threads_codegen)1329 std::string generate_reduction_code(
1330 int nOutputs,
1331 const std::string& func_,
1332 const std::string& name,
1333 const int vt0,
1334 const std::string& f_inputs_type,
1335 const std::string& reduction_accum_type,
1336 const std::string& result_type,
1337 bool contiguous,
1338 bool vectorized,
1339 int vec_size,
1340 int max_threads_codegen) {
1341 std::string func = func_;
1342 at::jit::TemplateEnv env;
1343 env.s("index_type", "unsigned int");
1344 env.s("scalar_type", f_inputs_type);
1345 env.s("result_type", result_type);
1346 env.s("reduction_accum_type", reduction_accum_type);
1347 env.s("vt0", std::to_string(vt0));
1348 env.s("name", name);
1349 env.s("max_threads_lb", std::to_string(max_threads_codegen));
1350 // reductions don't support dynamic casting, so the only way to get nonstandard types
1351 // is through input
1352 if (f_inputs_type == "at::Half" || f_inputs_type == "std::complex<at::Half>") {
1353 // complex<Half> depends on complex<T> and Half dtypes.
1354 env.s("half_string", jiterator_half_support_literal);
1355 } else {
1356 env.s("half_string", "");
1357 }
1358 if (f_inputs_type == "at::BFloat16") {
1359 env.s("bfloat16_string", jiterator_bfloat16_support_literal);
1360 } else {
1361 env.s("bfloat16_string", "");
1362 }
1363 if (f_inputs_type == "std::complex<float>" ||
1364 f_inputs_type == "std::complex<double>" ||
1365 f_inputs_type == "std::complex<at::Half>" ) {
1366 // complex<Half> depends on complex<T> and Half dtypes.
1367 env.s("traits_string", get_traits_string_but_hiprtc_safe());
1368 env.s("complex_body_string", get_complex_body_string());
1369 env.s("complex_math_string", get_complex_math_string());
1370 env.s("complex", std::to_string(1));
1371 #ifdef USE_ROCM
1372 // unhipify math functions, but only if std::complex is used.
1373 func = unhipify_math_functions(func);
1374 #endif
1375 } else {
1376 env.s("traits_string", "");
1377 env.s("complex_body_string", "");
1378 env.s("complex_math_string", "");
1379 env.s("complex", std::to_string(0));
1380 }
1381 if (f_inputs_type == "std::complex<at::Half>") {
1382 env.s("complex_half_body_string", get_complex_half_body_string());
1383 } else {
1384 env.s("complex_half_body_string", "");
1385 }
1386 env.s("cmath_string", get_cmath_string());
1387 env.s("functor", func);
1388 env.s("output_vec_size", std::to_string(vec_size));
1389 static auto cuda_template = at::jit::CodeTemplate(
1390 jit_preamble + jit_common_types + offset_calc_template + get_reduction_template() + jit_epilogue);
1391 const auto code = cuda_template.format(env);
1392 return code;
1393 }
1394
1395 // Acquires (possibly creating) the kernel cache directory
get_cache_dir()1396 std::optional<std::string> get_cache_dir() {
1397 // If the environment variable USE_TORCH_KERNEL_CACHE is set to "0" then no persistent cache is used
1398 const char* uptkc = std::getenv("USE_PYTORCH_KERNEL_CACHE");
1399 const bool use_kernel_cache = (uptkc == nullptr) ? true : std::strcmp(uptkc, "0");
1400
1401 if (!use_kernel_cache) {
1402 return {};
1403 }
1404
1405 // Cache path comes from PYTORCH_KERNEL_CACHE_PATH, then TEMP (Windows) or XDG_CACHE_HOME (Linux), then HOME environment variables
1406 std::string cache_dir;
1407 char* ptkcp = std::getenv("PYTORCH_KERNEL_CACHE_PATH");
1408 // Create kernel_cache_dir if needed as we do not want to create the base directory passed by the user
1409 std::string kernels_cache_dir = "";
1410 if (ptkcp != nullptr) {
1411 cache_dir = std::string(ptkcp);
1412 } else {
1413 #ifdef _WIN32
1414 ptkcp = std::getenv("TEMP");
1415 #else
1416 // USES XDG_CACHE_HOME if it's set
1417 ptkcp = std::getenv("XDG_CACHE_HOME");
1418 #endif
1419 if (ptkcp != nullptr) {
1420 kernels_cache_dir = "/torch/kernels";
1421 cache_dir = std::string(ptkcp) + kernels_cache_dir;
1422 } else {
1423 // Falls back to HOME/.cache
1424 ptkcp = std::getenv("HOME");
1425 if (ptkcp == nullptr) {
1426 TORCH_WARN_ONCE("No PYTORCH_KERNEL_CACHE_PATH or HOME environment variable set!",
1427 " This disables kernel caching.");
1428 return {};
1429 } else {
1430 kernels_cache_dir = "/.cache/torch/kernels";
1431 cache_dir = std::string(ptkcp) + kernels_cache_dir;
1432 }
1433 }
1434 }
1435
1436 // Creates the cache directory if it does not exist
1437 const char* p_cache_dir = cache_dir.c_str();
1438 const bool cache_dir_exists = (access(p_cache_dir, F_OK) == 0);
1439 if (!cache_dir_exists) {
1440 std::string s_ptkcp = std::string(ptkcp);
1441 if (!r_mkdir_with_base(s_ptkcp, kernels_cache_dir)) {
1442 TORCH_WARN_ONCE("Specified kernel cache directory could not be created! This disables kernel caching.",
1443 " Specified directory is ", cache_dir, ".",
1444 " This warning will appear only once per process.");
1445 return {};
1446 }
1447 }
1448
1449 // Checks that the cache directory is readable and writable
1450 const bool cache_dir_readable = (access(p_cache_dir, R_OK) == 0);
1451 if (!cache_dir_readable) {
1452 TORCH_WARN_ONCE("Specified kernel cache directory is not readable! This disables kernel caching.",
1453 " Specified directory is ", cache_dir, ".",
1454 " This warning will appear only once per process.");
1455 return {};
1456 }
1457
1458 const bool cache_dir_writable = (access(p_cache_dir, W_OK) == 0);
1459 if (!cache_dir_writable) {
1460 TORCH_WARN_ONCE("Specified kernel cache directory is not writable! This disables kernel caching.",
1461 " Specified directory is ", cache_dir, ".",
1462 " This warning will appear only once per process.");
1463 return {};
1464 }
1465
1466 return cache_dir;
1467 }
1468
1469 // Compiles the kernel, or acquires if from the cache if caching
jit_pwise_function(const std::string & code,const std::string & kernel_name)1470 NvrtcFunction jit_pwise_function(
1471 const std::string& code,
1472 const std::string& kernel_name) {
1473 initializeCudaContext();
1474 // Acquires CUDA and nvrtc versions and whether we're compiling to ptx or SASS
1475 const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
1476 int cuda_major = 0, cuda_minor = 0, nvrtc_major = 0, nvrtc_minor = 0;
1477 bool compile_to_sass = false;
1478 at::cuda::jit::codegenOutputQuery(
1479 prop, cuda_major, cuda_minor, nvrtc_major, nvrtc_minor, compile_to_sass);
1480
1481 // Objects used whether loading from the cache or jit compiling
1482 const auto& nvrtc = at::globalContext().getNVRTC();
1483 NvrtcFunction compiled_kernel_;
1484 std::string name = kernel_name + "_kernel";
1485
1486 static const std::optional<std::string> cache_dir = get_cache_dir();
1487
1488 std::string file_path;
1489 if (cache_dir.has_value()) {
1490 // Attemps to read from the cache.
1491 // Cubin name is <kernel name>_arch<major>.<minor>_nvrtc<major>.<minor>_<ptx or sass>_<program length>_<string hash>
1492 // Note that the SHA1 hash used in the file name is NOT the SHA1 hash of the file's contents,
1493 // because we hash on the CUDA code, but we save the compiled ptx or sass
1494
1495 // Acquires SHA1 hash
1496 c10::sha1 sha1_hash{code};
1497 const auto hash_code = sha1_hash.str();
1498
1499 // Constructs file path by appending constructed cubin name to cache path
1500 std::stringstream ss;
1501 ss << *cache_dir << "/";
1502 ss << kernel_name;
1503 #ifdef USE_ROCM
1504 ss << "_arch" << prop->gcnArchName;
1505 #else
1506 ss << "_arch" << cuda_major << "." << cuda_minor;
1507 #endif
1508 ss << "_nvrtc" << nvrtc_major << "." << nvrtc_minor;
1509 ss << (compile_to_sass ? "_sass" : "_ptx");
1510 ss << "_" << code.length();
1511 ss << "_" << hash_code;
1512 file_path = ss.str();
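    // Illustrative example of the resulting name (values are made up):
    // <cache_dir>/my_kernel_arch8.6_nvrtc12.1_sass_48231_<40-char sha1 hex>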

    std::ifstream readin{file_path, std::ios::in | std::ifstream::binary};
    if (readin.fail()) {
      // NOTE: this does not warn because the file might not exist
      // TODO: consider if this should explicitly check for the file's existence or not to throw
      // an informative warning
      readin.close();
    } else {
      // TODO: try passing the "mapped" file directly to cuModuleLoadCall instead of using an intermediate buffer
      std::vector<char> buffer(std::istreambuf_iterator<char>(readin), {});
      AT_CUDA_DRIVER_CHECK(nvrtc.cuModuleLoadData(&(compiled_kernel_.module), buffer.data()));
      AT_CUDA_DRIVER_CHECK(
          nvrtc.cuModuleGetFunction(&(compiled_kernel_.function), compiled_kernel_.module, name.c_str()));
      readin.close();
      return compiled_kernel_;
    }
  }

  // Just-in-time compiles the program

  // Creates the NVRTC program
  nvrtcProgram program;
  AT_CUDA_NVRTC_CHECK(nvrtc.nvrtcCreateProgram(
      &program, code.c_str(), nullptr, 0, nullptr, nullptr));

#ifdef USE_ROCM
  std::vector<const char*> args = {"--std=c++17"};
#else
  // Constructs nvrtc build arguments
  // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX (compute_)
  // which gives better backwards compatibility to work on older driver,
  // (since older driver doesn't necessarily recognize PTX emitted by new
  // toolkit);
  // Meanwhile, for forward compatibility (future device with
  // `unsupported_arch==True`), since SASS are not necessarily compatible,
  // we fallback to PTX instead.
  const std::string compute = std::string("--gpu-architecture=") +
      (compile_to_sass ? "sm_" : "compute_") + std::to_string(cuda_major) +
      std::to_string(cuda_minor);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<const char*> args = {
      "--std=c++17", compute.c_str(), "-default-device"};
#endif

#ifndef NDEBUG
  // Add line info to generated kernels
  args.push_back("-lineinfo");
#else
  // Avoid excessive register usage from assertion
  args.push_back("-DNDEBUG");
#endif

  const auto compilation_result =
      nvrtc.nvrtcCompileProgram(program, args.size(), args.data());

  // Throws an error on compilation failure
  if (compilation_result != NVRTC_SUCCESS) {
    size_t logsize;
    AT_CUDA_NVRTC_CHECK(nvrtc.nvrtcGetProgramLogSize(program, &logsize));
    std::string log(logsize, '\0');
    AT_CUDA_NVRTC_CHECK(nvrtc.nvrtcGetProgramLog(program, &log[0]));
    throw std::runtime_error(code + log);
  }

  size_t ptx_size = 0;
  std::vector<char> ptx;
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
  // compile_to_sass determines whether we are generating SASS or PTX, hence
  // the different API.
  const auto getSize = compile_to_sass
      ? at::globalContext().getNVRTC().nvrtcGetCUBINSize
      : at::globalContext().getNVRTC().nvrtcGetPTXSize;
  const auto getFunc = compile_to_sass
      ? at::globalContext().getNVRTC().nvrtcGetCUBIN
      : at::globalContext().getNVRTC().nvrtcGetPTX;
#else
  const auto getSize = at::globalContext().getNVRTC().nvrtcGetPTXSize;
  const auto getFunc = at::globalContext().getNVRTC().nvrtcGetPTX;
#endif

  AT_CUDA_NVRTC_CHECK(getSize(program, &ptx_size));
  ptx.resize(ptx_size);
  AT_CUDA_NVRTC_CHECK(getFunc(program, ptx.data()));

  AT_CUDA_DRIVER_CHECK(nvrtc.cuModuleLoadData(&(compiled_kernel_.module), ptx.data()));

  AT_CUDA_DRIVER_CHECK(
      nvrtc.cuModuleGetFunction(&(compiled_kernel_.function), compiled_kernel_.module, name.c_str()));
  // TODO: use guards to avoid leaking
  AT_CUDA_NVRTC_CHECK(nvrtc.nvrtcDestroyProgram(&program));

  if (cache_dir.has_value()) {
    // Writes the program to the cache if caching
    // NOTE: Actually writes to a per-process temporary file to avoid multi-process contention.
    // The temporary file is then renamed to the actual file.
    // If the actual file already exists then the rename may fail or replace the actual file,
    // the behavior is implementation-specific.
    // Files replaced through this process should remain extant if they are being read because
    // of UNIX filesystem properties, but this behavior is unverified and may require
    // additional review in the future.
    // TODO: In C++17 we should be able to use the filesystem header.
    const auto pid = getpid();
    std::stringstream tmp_file_path_ss;
    tmp_file_path_ss << file_path << "_tmp_" << pid;
    const std::string tmp_file_path = tmp_file_path_ss.str();
    std::ofstream cubin(tmp_file_path, std::ios::out | std::ofstream::binary);
    if (cubin.fail()) {
      TORCH_WARN_ONCE("Failed to write temporary kernel cache file!",
                      " File path was ", tmp_file_path, ".",
                      " This warning will only appear once per process.");
    } else {
      std::copy(ptx.begin(), ptx.end(), std::ostreambuf_iterator<char>(cubin));
      if (std::rename(tmp_file_path.c_str(), file_path.c_str()) != 0) {
        // Removes tmp file if the rename failed
        std::remove(tmp_file_path.c_str());
      }
    }
    cubin.close();
  }

  return compiled_kernel_;
}

// TODO: may need/want to initialize CUDA context here (refactor into nvrtc call)
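// `args` follows the CUDA driver API kernelParams convention: an array of
// pointers, one per kernel parameter, forwarded to cuLaunchKernel below.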
void launch_jitted_pwise_function(
    NvrtcFunction function,
    void* args[],
    const dim3 nBlocks,
    const dim3 kBlockSize,
    const int smem) {
  initializeCudaContext();
  const auto& nvrtc = at::globalContext().getNVRTC();
  // Launches kernel on current stream
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_CUDA_DRIVER_CHECK(nvrtc.cuLaunchKernel(
      function.function,
      nBlocks.x,
      nBlocks.y,
      nBlocks.z,
      kBlockSize.x,
      kBlockSize.y,
      kBlockSize.z,
      smem,
      stream,
      args,
      nullptr));
}

} // at::cuda::jit