xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/jit_utils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <c10/core/ScalarType.h>
3 #include <c10/util/irange.h>
4 #include <c10/util/hash.h>
5 #include <optional>
6 #include <ATen/jit_macros.h>
7 #include <ATen/cuda/CUDAContext.h>
8 #include <ATen/cuda/detail/OffsetCalculator.cuh>
9 #include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
10 #include <ATen/code_template.h>
11 #include <ATen/OpMathType.h>
12 #include <ATen/native/cuda/jit_utils.h>
13 #include <ATen/cuda/llvm_jit_strings.h>
14 #include <ATen/native/cuda/reduction_template.cuh>
15 
16 #include <sstream>
17 #include <fstream>
18 #include <cstdio>
19 #include <iterator> // istreambuf_iterator
20 #include <cstdlib>
21 #include <string>
22 
23 // TODO: C++17 has the filesystem header, which may replace these
24 #ifdef _WIN32
25   // On Windows, the POSIX implementations are considered deprecated. We simply map to the newer variant.
26   #include <process.h>
27   #include <direct.h>
28   #include <io.h>
29   #define access _access
30   #define getpid _getpid
31   #define R_OK    4
32   #define W_OK    2
33   #define F_OK    0
34 #else
35   #include <sys/types.h>
36   #include <sys/stat.h> // mkdir
37   #include <unistd.h>
38 #endif
39 
40 
41 namespace at::cuda::jit {
42 
43 // hiprtc already includes some traits, so this removes duplicate definitions of
44 // integral_constant, is_same, is_integral, enable_if, is_floating_point, is_arithmetic.
45 // Copied from aten/src/ATen/cuda/llvm_basic.cpp, then modified as above.
46 // If not compiling for ROCm, return the original get_traits_string().
47 std::string get_traits_string_but_hiprtc_safe() {
48 #ifdef USE_ROCM
49     return R"ESCAPE(
50 namespace std {
51 
52 template <class _Tp>
53 _Tp&& __declval(int);
54 template <class _Tp>
55 _Tp __declval(long);
56 template <class _Tp>
57 decltype(__declval<_Tp>(0)) declval() noexcept;
58 
59 template <class _Tp> struct remove_const            {typedef _Tp type;};
60 template <class _Tp> struct remove_const<const _Tp> {typedef _Tp type;};
61 template <class _Tp> using remove_const_t = typename remove_const<_Tp>::type;
62 
63 template <class _Tp> struct remove_volatile               {typedef _Tp type;};
64 template <class _Tp> struct remove_volatile<volatile _Tp> {typedef _Tp type;};
65 template <class _Tp> using remove_volatile_t = typename remove_volatile<_Tp>::type;
66 
67 template <class _Tp> struct remove_cv
68 {typedef typename remove_volatile<typename remove_const<_Tp>::type>::type type;};
69 template <class _Tp> using remove_cv_t = typename remove_cv<_Tp>::type;
70 
71 template <class _Tp> struct __libcpp_is_floating_point              : public false_type {};
72 template <>          struct __libcpp_is_floating_point<float>       : public true_type {};
73 template <>          struct __libcpp_is_floating_point<double>      : public true_type {};
74 template <>          struct __libcpp_is_floating_point<long double> : public true_type {};
75 
76 template <class _Tp>
77 inline constexpr bool is_arithmetic_v = is_arithmetic<_Tp>::value;
78 
79 template <class _Tp>
80 struct __numeric_type
81 {
82    static void __test(...);
83    static float __test(float);
84    static double __test(char);
85    static double __test(int);
86    static double __test(unsigned);
87    static double __test(long);
88    static double __test(unsigned long);
89    static double __test(long long);
90    static double __test(unsigned long long);
91    static double __test(double);
92    static long double __test(long double);
93 
94    typedef decltype(__test(declval<_Tp>())) type;
95    static const bool value = !is_same<type, void>::value;
96 };
97 
98 template <>
99 struct __numeric_type<void>
100 {
101    static const bool value = true;
102 };
103 
104 // __promote
105 
106 template <class _A1, class _A2 = void, class _A3 = void,
107           bool = __numeric_type<_A1>::value &&
108                  __numeric_type<_A2>::value &&
109                  __numeric_type<_A3>::value>
110 class __promote_imp
111 {
112 public:
113     static const bool value = false;
114 };
115 
116 template <class _A1, class _A2, class _A3>
117 class __promote_imp<_A1, _A2, _A3, true>
118 {
119 private:
120     typedef typename __promote_imp<_A1>::type __type1;
121     typedef typename __promote_imp<_A2>::type __type2;
122     typedef typename __promote_imp<_A3>::type __type3;
123 public:
124     typedef decltype(__type1() + __type2() + __type3()) type;
125     static const bool value = true;
126 };
127 
128 template <class _A1, class _A2>
129 class __promote_imp<_A1, _A2, void, true>
130 {
131 private:
132     typedef typename __promote_imp<_A1>::type __type1;
133     typedef typename __promote_imp<_A2>::type __type2;
134 public:
135     typedef decltype(__type1() + __type2()) type;
136     static const bool value = true;
137 };
138 
139 template <class _A1>
140 class __promote_imp<_A1, void, void, true>
141 {
142 public:
143     typedef typename __numeric_type<_A1>::type type;
144     static const bool value = true;
145 };
146 
147 template <class _A1, class _A2 = void, class _A3 = void>
148 class __promote : public __promote_imp<_A1, _A2, _A3> {};
149 
150 } // namespace std
151 )ESCAPE";
152 #else
153     return get_traits_string();
154 #endif
155 }
156 
157 #ifdef USE_ROCM
158 const std::string jit_preamble = R"ESCAPE(
159 #pragma clang force_cuda_host_device begin
160 )ESCAPE";
161 const std::string jit_epilogue = R"ESCAPE(
162 #pragma clang force_cuda_host_device end
163 )ESCAPE";
164 #else
165 const std::string jit_preamble;
166 const std::string jit_epilogue;
167 #endif
168 
169 const std::string jit_common_types = R"ESCAPE(
170   #ifdef __HIPCC__
171   #define ERROR_UNSUPPORTED_CAST ;
172   // corresponds to aten/src/ATen/native/cuda/thread_constants.h
173   #define CUDA_OR_ROCM_NUM_THREADS 256
174   // corresponds to aten/src/ATen/cuda/detail/OffsetCalculator.cuh
175   #define MAX_DIMS 16
176   #ifndef __forceinline__
177   #define __forceinline__ inline __attribute__((always_inline))
178   #endif
179   #else
180   //TODO use _assert_fail, because assert is disabled in non-debug builds
181   #define ERROR_UNSUPPORTED_CAST assert(false);
182   #define CUDA_OR_ROCM_NUM_THREADS 128
183   #define MAX_DIMS 25
184   #endif
185   #define POS_INFINITY __int_as_float(0x7f800000)
186   #define INFINITY POS_INFINITY
187   #define NEG_INFINITY __int_as_float(0xff800000)
188   #define NAN __int_as_float(0x7fffffff)
189 
190   typedef long long int int64_t;
191   typedef unsigned int uint32_t;
192   typedef signed char int8_t;
193   typedef unsigned char uint8_t;  // NOTE: this MUST be "unsigned char"! plain "char" may be signed (and is a distinct type regardless)
194   typedef short int16_t;
195   static_assert(sizeof(int64_t) == 8, "expected size does not match");
196   static_assert(sizeof(uint32_t) == 4, "expected size does not match");
197   static_assert(sizeof(int8_t) == 1, "expected size does not match");
198   constexpr int num_threads = CUDA_OR_ROCM_NUM_THREADS;
199   constexpr int thread_work_size = 4; // TODO: make template substitution once we decide where those vars live
200   constexpr int block_work_size = thread_work_size * num_threads;
201 
202   ${traits_string}
203   ${cmath_string}
204 
205   // NB: Order matters for this macro; it is relied upon in
206   // _promoteTypesLookup and the serialization format.
207   // Note, some types have ctype as void because we don't support them in codegen
208   #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
209   _(uint8_t, Byte) /* 0 */                               \
210   _(int8_t, Char) /* 1 */                                \
211   _(int16_t, Short) /* 2 */                              \
212   _(int, Int) /* 3 */                                    \
213   _(int64_t, Long) /* 4 */                               \
214   _(at::Half, Half) /* 5 */                                  \
215   _(float, Float) /* 6 */                                \
216   _(double, Double) /* 7 */                              \
217   _(std::complex<at::Half>, ComplexHalf) /* 8 */        \
218   _(std::complex<float>, ComplexFloat) /* 9 */                          \
219   _(std::complex<double>, ComplexDouble) /* 10 */                         \
220   _(bool, Bool) /* 11 */                                 \
221   _(void, QInt8) /* 12 */                          \
222   _(void, QUInt8) /* 13 */                        \
223   _(void, QInt32) /* 14 */                        \
224   _(at::BFloat16, BFloat16) /* 15 */                             \
225 
226   #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_QINT(_)       \
227   _(uint8_t, Byte)                                                 \
228   _(int8_t, Char)                                                  \
229   _(int16_t, Short)                                                \
230   _(int, Int)                                                      \
231   _(int64_t, Long)                                                 \
232   _(at::Half, Half)                                                \
233   _(float, Float)                                                  \
234   _(double, Double)                                                \
235   _(std::complex<at::Half>, ComplexHalf)                           \
236   _(std::complex<float>, ComplexFloat)                             \
237   _(std::complex<double>, ComplexDouble)                           \
238   _(bool, Bool)                                                    \
239   _(at::BFloat16, BFloat16)
240 
241 
242   enum class ScalarType : int8_t {
243   #define DEFINE_ENUM(_1, n) n,
244   AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ENUM)
245   #undef DEFINE_ENUM
246       Undefined,
247   NumOptions
248   };
249 
250   template <typename T, int size>
251   struct Array {
252   T data[size];
253 
254   __device__ T operator[](int i) const {
255       return data[i];
256   }
257   __device__ T& operator[](int i) {
258       return data[i];
259   }
260   Array() = default;
261   Array(const Array&) = default;
262   Array& operator=(const Array&) = default;
263   __device__ Array(T x) {
264     for (int i = 0; i < size; i++) {
265       data[i] = x;
266     }
267   }
268   };
269 
270   ${half_string}
271   ${bfloat16_string}
272   ${complex_body_string}
273   ${complex_half_body_string}
274   ${complex_math_string}
275 
276 
277 )ESCAPE";
278 
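// Illustrative sketch (not part of the build): the ${...} placeholders above are filled in
// by at::jit::TemplateEnv / at::jit::CodeTemplate before the source is handed to NVRTC,
// roughly as generate_code() does further down in this file:
//
//   at::jit::TemplateEnv env;
//   env.s("traits_string", "");                  // or get_traits_string_but_hiprtc_safe()
//   env.s("cmath_string", get_cmath_string());
//   env.s("half_string", "");                    // only filled for Half inputs/outputs
//   // ... the remaining placeholders are set the same way ...
//   std::string src = at::jit::CodeTemplate(jit_common_types).format(env);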
279 // We need to include the half, bfloat16 and complex strings in all kernels with half arguments and in all kernels with type casting,
280 // regardless of whether they have half arguments (because fetch_and_cast and cast_and_store loop over all types).
281 const std::string jiterator_half_support_literal = R"ESCAPE(
282 namespace at {
283 struct alignas(2) Half {
284   unsigned short x;
285 
286   Half() = default;
287   inline __host__ __device__ Half(float value){
288 #ifdef __HIPCC__
289     x = __half_as_short(__float2half(value));
290 #else
291     asm("{  cvt.rn.f16.f32 %0, %1;}\n" : "=h"(x) : "f"(value));
292 #endif
293   }
294   inline __host__ __device__ operator float() const{
295 #ifdef __HIPCC__
296       return __half2float(*reinterpret_cast<const __half*>(&x));
297 #else
298       float val;
299       asm("{  cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(x)); // do we need const cast here?
300       //asm("{  cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__HALF_TO_CUS(x)));
301       return val;
302 #endif
303   }
304 };
305 }
306 )ESCAPE";
307 
308 const std::string jiterator_bfloat16_support_literal = R"ESCAPE(
309 namespace at {
310 struct alignas(2) BFloat16 {
311   unsigned short x;
312 
313   __device__ unsigned short __internal_float2bfloat16(
314       const float f,
315       unsigned int& sign,
316       unsigned int& remainder) {
317     unsigned int x;
318 
319     x = __float_as_uint(f);
320 
321     if ((x & 0x7fffffffU) > 0x7f800000U) {
322       sign = 0U;
323       remainder = 0U;
324       return static_cast<unsigned short>(0x7fffU);
325     }
326     sign = x >> 31;
327     remainder = x << 16;
328     return static_cast<unsigned short>(x >> 16);
329   }
330 
331 
332   BFloat16() = default;
333   inline __host__ __device__ BFloat16(float value){
334   #if __CUDA_ARCH__ >= 800
335   asm("{  cvt.rn.bf16.f32 %0, %1;}\n" : "=h"(x) : "f"(value));
336   )ESCAPE"
337   R"ESCAPE(
338   #else
339   unsigned int sign;
340   unsigned int remainder;
341   x = __internal_float2bfloat16(value, sign, remainder);
342   if ((remainder > 0x80000000U) ||
343       ((remainder == 0x80000000U) && ((x & 0x1U) != 0U))) {
344     x++;
345   }
346   #endif
347   }
348 
349   inline __host__ __device__ operator float() const{
350 #ifdef __HIPCC__
351     union
352     {
353         uint32_t int32;
354         float    fp32;
355     } u = {uint32_t(x) << 16};
356     return u.fp32;
357 #else
358     float val;
359     asm("{ mov.b32 %0, {0,%1};}\n" : "=f"(val) : "h"(x)); //do we need const cast here?
360     return val;
361 #endif
362   }
363 
364 };
365 }
366 )ESCAPE";
367 
368 // From c10/util/Load.h
369 const std::string load_support_literal = R"ESCAPE(
370 
371   namespace c10 {
372     template <typename T>
373     struct LoadImpl {
374       __device__ static T apply(const void *src) {
375         return *reinterpret_cast<const T*>(src);
376       }
377     };
378 
379     template <>
380     struct LoadImpl<bool> {
381       __device__ static bool apply(const void *src) {
382         static_assert(sizeof(bool) == sizeof(char), "");
383         return LoadImpl<char>::apply(src);
384       }
385     };
386 
387     template <typename T>
388     __device__ T load(const void *src) {
389       return LoadImpl<T>::apply(src);
390     }
391 
392     template <typename scalar_t>
393     __device__ scalar_t load(const scalar_t *src) {
394       return LoadImpl<scalar_t>::apply(src);
395     }
396   }  // namespace c10
397 
398 )ESCAPE";
399 
400 // copy-pasted from c10/util/TypeCast.h and c10/core/DynamicCast.h
401 const std::string dynamic_cast_support_literal = R"ESCAPE(
402 
403   template <typename T>
404   struct is_complex : public std::false_type {};
405 
406   template <typename T>
407   struct is_complex<std::complex<T>> : public std::true_type {};
408 
409   template <typename dest_t, typename src_t>
410   struct needs_real {
411     constexpr static bool value =
412         (is_complex<src_t>::value && !is_complex<dest_t>::value);
413   };
414 
415   template <bool, typename src_t>
416   struct maybe_real {
417     static inline src_t apply(src_t src) {
418       return src;
419     }
420   };
421 
422   template <typename src_t>
423   struct maybe_real<true, src_t> {
424     static inline decltype(auto) apply(src_t src) {
425       return src.real();
426     }
427   };
428 
429   template <typename dest_t, typename src_t>
430   struct static_cast_with_inter_type {
431     static inline dest_t apply(
432         src_t src) {
433       constexpr bool real = needs_real<dest_t, src_t>::value;
434       return static_cast<dest_t>(maybe_real<real, src_t>::apply(src));
435     }
436   };
437 
438   template <typename src_t>
439   struct static_cast_with_inter_type<uint8_t, src_t> {
440     static inline uint8_t apply(
441         src_t src) {
442       constexpr bool real = needs_real<uint8_t, src_t>::value;
443       return static_cast<uint8_t>(
444           static_cast<int64_t>(maybe_real<real, src_t>::apply(src)));
445     }
446   };
447 
448   template <>
449   struct static_cast_with_inter_type<std::complex<at::Half>, at::BFloat16> {
450     static inline std::complex<at::Half> apply(at::BFloat16 src) {
451       return static_cast<std::complex<at::Half>>(float{src});
452     }
453   };
454 
455   template <>
456   struct static_cast_with_inter_type<std::complex<at::Half>, at::Half> {
457     static inline std::complex<at::Half> apply(at::Half src) {
458       return static_cast<std::complex<at::Half>>(float{src});
459     }
460   };
461 
462   template <>
463   struct static_cast_with_inter_type<
464       std::complex<at::Half>,
465       std::complex<double>> {
466     static inline std::complex<at::Half> apply(std::complex<double> src) {
467       return static_cast<std::complex<at::Half>>(static_cast<std::complex<float>>(src));
468     }
469   };
470 
471   // Fetch a value with dynamic type src_type from ptr, and cast it to static type dest_t.
472   #define FETCH_AND_CAST_CASE(type, scalartype) \
473     case ScalarType::scalartype:                \
474       return static_cast_with_inter_type<dest_t, type>::apply(c10::load<type>(ptr));
475   template<typename dest_t>
476   __device__ inline dest_t fetch_and_cast(const ScalarType src_type, const void *ptr) {
477     switch (src_type) {
478         AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_QINT(FETCH_AND_CAST_CASE)
479         default:
480           ERROR_UNSUPPORTED_CAST
481     }
482     return dest_t(0); // just to avoid compiler warning
483   }
484 
485   // Cast a value with static type src_t into dynamic dest_type, and store it to ptr.
486   #define CAST_AND_STORE_CASE(type, scalartype)                             \
487     case ScalarType::scalartype:                                            \
488       *(type*)ptr = static_cast_with_inter_type<type, src_t>::apply(value); \
489       return;
490   template<typename src_t>
491   __device__ inline void cast_and_store(const ScalarType dest_type, void *ptr, src_t value) {
492   switch (dest_type) {
493       AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_QINT(CAST_AND_STORE_CASE)
494       default:;
495   }
496   ERROR_UNSUPPORTED_CAST
497   }
498 
499   template <int N>
500   struct LoadWithCast {
501     using array_t = Array<ScalarType, N==0? 1 : N>;
502     using size_array_t = Array<uint32_t, N==0? 1: N>;
503 
504     array_t dtypes;
505     size_array_t element_sizes;
506     template <typename scalar_t>
507     __device__ scalar_t load(char* base_ptr, uint32_t offset, int arg) {
508         void* ptr = base_ptr + element_sizes[arg] * offset;
509         return fetch_and_cast<scalar_t>(dtypes[arg], ptr);
510     }
511   };
512 
513   template <int N = 1>
514   struct StoreWithCast {
515     using array_t = Array<ScalarType, N==0? 1 : N>;
516     using size_array_t = Array<uint32_t, N==0? 1: N>;
517 
518     array_t dtypes;
519     size_array_t element_sizes;
520 
521     template<typename scalar_t>
522     __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
523         void *ptr = base_ptr + element_sizes[arg] * offset;
524         cast_and_store<scalar_t>(dtypes[arg], ptr, value);
525     }
526   };
527 
528 )ESCAPE";
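// Illustrative example of the dispatch above (values chosen for illustration): with a
// runtime dtype of ScalarType::Double,
//   float v = fetch_and_cast<float>(ScalarType::Double, ptr);
// expands through FETCH_AND_CAST_CASE to
//   static_cast_with_inter_type<float, double>::apply(c10::load<double>(ptr));
// so the source type is chosen at runtime while the destination type stays static.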
529 
530 const std::string no_dynamic_cast_support_literal = R"ESCAPE(
531 
532   struct LoadWithoutCast {
533   template <typename scalar_t>
534   __device__ scalar_t load(char* base_ptr, uint32_t offset, int arg=0) {
535     return c10::load(reinterpret_cast<scalar_t*>(base_ptr) + offset);
536   }
537   };
538 
539   struct StoreWithoutCast {
540   template<typename scalar_t>
541   __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg=0) {
542     *(reinterpret_cast<scalar_t *>(base_ptr) + offset) = value;
543   }
544   };
545 
546 )ESCAPE";
547 
548 const std::string offset_calc_template = R"ESCAPE(
549   template <typename T>
550   struct DivMod {
551   T div;
552   T mod;
553 
554   __device__ DivMod(T _div, T _mod) {
555       div = _div;
556       mod = _mod;
557   }
558   };
559 
560   //<unsigned int>
561   struct IntDivider {
562   IntDivider() = default;
563 
564   __device__ inline unsigned int div(unsigned int n) const {
565   unsigned int t = __umulhi(n, m1);
566   return (t + n) >> shift;
567   }
568 
569   __device__ inline unsigned int mod(unsigned int n) const {
570   return n - div(n) * divisor;
571   }
572 
573   __device__ inline DivMod<unsigned int> divmod(unsigned int n) const {
574   unsigned int q = div(n);
575   return DivMod<unsigned int>(q, n - q * divisor);
576   }
577 
578   unsigned int divisor;  // the divisor d
579   unsigned int m1;  // magic number m' for the fast division
580   unsigned int shift;  // shift amount
581   };
582 
583   template <int NARGS>
584   struct TrivialOffsetCalculator {
585     // The offset for each argument. Wrapper around fixed-size array.
586     // The offsets are in # of elements, not in bytes.
587     Array<${index_type}, NARGS> get(${index_type} linear_idx) const {
588       Array<${index_type}, NARGS> offsets;
589       #pragma unroll
590       for (int arg = 0; arg < NARGS; arg++) {
591         offsets[arg] = linear_idx;
592       }
593       return offsets;
594     }
595   };
596 
597   template<int NARGS>
598   struct OffsetCalculator {
599   OffsetCalculator() = default;
600   __device__ __forceinline__ Array<${index_type}, NARGS> get(${index_type} linear_idx) const {
601       Array<${index_type}, NARGS> offsets;
602       #pragma unroll
603       for (int arg = 0; arg < NARGS; ++arg) {
604       offsets[arg] = 0;
605       }
606 
607       #pragma unroll
608       for (int dim = 0; dim < MAX_DIMS; ++dim) {
609       if (dim == dims) {
610           break;
611       }
612 
613       auto divmod = sizes_[dim].divmod(linear_idx);
614       linear_idx = divmod.div;
615 
616       #pragma unroll
617       for (int arg = 0; arg < NARGS; ++arg) {
618           offsets[arg] += divmod.mod * strides_[dim][arg];
619       }
620       //printf("offset calc thread dim size stride offset %d %d %d %d %d %d %d %d\n",
621       //threadIdx.x, dim, sizes_[dim].divisor, strides_[dim][0], offsets[0], linear_idx, divmod.div, divmod.mod);
622       }
623       return offsets;
624   }
625 
626     int dims;
627     IntDivider sizes_[MAX_DIMS];
628     // NOTE: this approach will not support nInputs == 0
629     ${index_type} strides_[MAX_DIMS][NARGS];
630   };
631 
632 
633 )ESCAPE";
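// Aside on IntDivider above: it is the usual "magic number" trick for fast unsigned
// division; m1 and shift are precomputed on the host (the struct mirrors ATen's
// OffsetCalculator.cuh) so that div(n) == n / divisor for every 32-bit n.
// Worked example (assuming the standard construction):
//   divisor == 3, m1 == 0x55555556, shift == 2
//   div(7) == (__umulhi(7, m1) + 7) >> 2 == (2 + 7) >> 2 == 2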
634 
635 const std::string jit_code_template = R"ESCAPE(
636 
637   ${load_support}
638   ${dynamic_casting_string}
639 
640 
641   ${functor}
642 
643   // TODO: setup grid-stride loop
644   extern "C" __global__
645   void ${name}_kernel(
646       const int numel,
647       Array<char*, ${nInputs}+${nOutputs}> data, //[${nInputs}+${nOutputs}],
648       ${offset_calculator}<${nInputs}> input_calculator,
649       ${offset_calculator}<${nOutputs}> output_calculator,
650       ${loader} l,
651       ${storer} s,
652       ${compute_type} scalar_val${extra_params}) {
653     ${declare_load_arrays}
654     ${declare_store_arrays}
655 
656     int idx = blockIdx.x;
657 
658     int remaining = numel - block_work_size * idx;
659     int thread_idx = threadIdx.x;
660 
661     #pragma unroll
662     for (int j = 0; j < thread_work_size; j++){
663         if (thread_idx >= remaining) {
664             break;
665         }
666 
667         int linear_idx = thread_idx + block_work_size * idx;
668         auto input_offsets = input_calculator.get(linear_idx);
669         ${load_inputs}
670         // printf(
671         //    "thread %d a %f offsets %d\n", threadIdx.x, arg0[j], input_offsets[0]);
672         thread_idx += num_threads;
673     }
674 
675     #pragma unroll
676     for (int j = 0; j < thread_work_size; j++) {
677       if ((threadIdx.x  + j*num_threads) < remaining) {
678         ${call_functor}
679       }
680     }
681 
682     thread_idx = threadIdx.x;
683     #pragma unroll
684     for (int j = 0; j < thread_work_size; j++){
685         if (thread_idx >= remaining) {
686             break;
687         }
688         //TODO maybe think about unifying offset calculators and reuse
689         //offsets computed in the load loop
690         int linear_idx = thread_idx + block_work_size * idx;
691         auto output_offsets = output_calculator.get(linear_idx);
692         //printf("output thread %d offset %d\n", threadIdx.x, output_offsets[0]);
693         ${store_outputs}
694         thread_idx += num_threads;
695     }
696   }
697 )ESCAPE";
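// Note: each block of the kernel above handles block_work_size elements in three
// unrolled phases (load, compute, store). With the defaults in jit_common_types
// (thread_work_size == 4, 128 CUDA threads) that is 4 * 128 == 512 elements per block,
// and the `remaining` check masks off the tail of the last block.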
698 
699 const std::string jit_vectorized_code_template = R"ESCAPE(
700 
701   ${load_support}
702 
703   template <typename scalar_t>
704   __device__ __inline__ scalar_t load(char* base_ptr, uint32_t offset) {
705       return c10::load(reinterpret_cast<scalar_t*>(base_ptr) + offset);
706   }
707 
708   template<typename scalar_t>
709   __device__ __inline__ void store(scalar_t value, char *base_ptr, uint32_t offset) {
710       *(reinterpret_cast<scalar_t *>(base_ptr) + offset) = value;
711   }
712 
713   // aligned vector generates vectorized load/store on CUDA
714   template<typename scalar_t, int vec_size>
715   struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
716     scalar_t val[vec_size];
717   };
718 
719   template <int vec_size, typename scalar_t>
720   __device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) {
721     using vec_t = aligned_vector<scalar_t, vec_size>;
722     auto *from = reinterpret_cast<const vec_t *>(base_ptr);
723     return from[offset];
724   }
725 
726   template <int vec_size>
727   __device__ aligned_vector<bool, vec_size> load_vector(const bool *base_ptr, uint32_t offset) {
728     // See NOTE [Loading boolean values]
729     auto tmp = load_vector<vec_size>(reinterpret_cast<const uint8_t*>(base_ptr), offset);
730     aligned_vector<bool, vec_size> ret;
731     for (int i = 0; i < vec_size; ++i) {
732       ret.val[i] = bool(tmp.val[i]);
733     }
734     return ret;
735   }
736 
737   ${functor}
738 
739   // TODO: setup grid-stride loop
740 
741   extern "C" __global__
742   void ${name}_vectorized${vec_size}_kernel(
743       const int N,
744       Array<char*, ${nInputs}+${nOutputs}> data,
745       ${compute_type} scalar_val${extra_params}) //[${nInputs}+${nOutputs}],
746       {
747       constexpr int vec_size = ${vec_size};
748       using scalar_t = ${scalar_type};
749       int remaining = N - block_work_size * blockIdx.x;
750       int thread_idx = threadIdx.x;
751       int idx = blockIdx.x;
752       ${declare_load_arrays}
753       ${declare_store_arrays}
754 
755       if (remaining < block_work_size) {
756         #pragma unroll
757         for (int j = 0; j < thread_work_size; j++){
758           if (thread_idx >= remaining) {
759             break;
760           }
761           int linear_idx = thread_idx + block_work_size * idx;
762           ${load_unrolled_inputs}
763           thread_idx += num_threads;
764         }
765         #pragma unroll
766         for (int j = 0; j < thread_work_size; j++) {
767           if ((threadIdx.x  + j*num_threads) < remaining) {
768             ${call_functor}
769           }
770         }
771         thread_idx = threadIdx.x;
772         #pragma unroll
773         for (int j = 0; j < thread_work_size; j++) {
774           if (thread_idx >= remaining) {
775               break;
776           }
777           int linear_idx = thread_idx + block_work_size * idx;
778           ${store_unrolled_outputs}
779           thread_idx += num_threads;
780         }
781       } else {
782         static constexpr int loop_size = thread_work_size / vec_size;
783   //actual loading
784         ${vector_inputs}
785         #pragma unroll
786         for (int i = 0; i<loop_size; i++){
787           ${load_vectorized_inputs}
788           thread_idx += num_threads;
789         }
790 
791         #pragma unroll
792         for (int j = 0; j < thread_work_size; j++) {
793           ${call_functor}
794         }
795 
796         using vec_t_output = aligned_vector<${result_type}, vec_size>;
797         ${vector_outputs}
798         int thread_idx = threadIdx.x;
799         #pragma unroll
800         for (int i = 0; i<loop_size; i++){
801           vec_t_output v;
802           ${store_vectorized_outputs}
803           thread_idx += num_threads;
804         }
805       }
806   }
807 )ESCAPE";
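// Note on the vectorized kernel above: a full block (remaining >= block_work_size) takes
// the aligned_vector path and moves vec_size elements per thread per iteration; a partial
// tail block falls back to the scalar, unrolled path guarded by `thread_idx >= remaining`.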
808 
809 static void replace_all(std::string& s, const std::string& to_replace, const std::string& replace_with) {
810   std::ostringstream oss;
811   std::size_t pos = 0;
812   std::size_t prev_pos = pos;
813 
814   while (true) {
815     prev_pos = pos;
816     pos = s.find(to_replace, pos);
817     if (pos == std::string::npos)
818       break;
819     oss << s.substr(prev_pos, pos - prev_pos);
820     oss << replace_with;
821     pos += to_replace.size();
822   }
823 
824   oss << s.substr(prev_pos);
825   s = oss.str();
826 }
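// Usage sketch (values are illustrative, not taken from this file):
//   std::string s = "a ::max b ::max";
//   replace_all(s, " ::max", " std::max");   // s becomes "a std::max b std::max"
// Occurrences are replaced left to right in a single pass over the string.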
827 
828 // hipify replaces certain device math functions, e.g., std::max -> ::max
829 // See torch/utils/hipify/cuda_to_hip_mappings.py.
830 // Replace them back. Search for " ::<name>" to avoid duplicate replacements.
831 static std::string unhipify_math_functions(const std::string &original) {
832   static std::vector<std::pair<std::string,std::string>> mappings = {
833     {" std::max", " ::max"},
834     {" std::min", " ::min"},
835     {" std::ceil", " ::ceil"},
836     {" std::floor", " ::floor"},
837     {" std::exp", " ::exp"},
838     {" std::log", " ::log"},
839     {" std::pow", " ::pow"},
840     {" std::fabs", " ::fabs"},
841     {" std::fmod", " ::fmod"},
842     {" std::remainder", " ::remainder"},
843     {" std::frexp", " ::frexp"}
844   };
845   std::string ret = original;
846   for (const auto& mapping : mappings) {
847     replace_all(ret, mapping.second, mapping.first);
848   }
849   return ret;
850 }
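// Example of the mapping direction (hypothetical functor string):
//   unhipify_math_functions("T r = ::max(a, ::min(b, c));")
// returns "T r = std::max(a, std::min(b, c));"; the leading space in each search key
// keeps calls that are already std::-qualified from matching.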
851 
852 // The following is copied from fused_kernel.cpp
853 // TODO: refactor codegenOutputQuery into its own file
854 //   that can be included by both files
855 // See NOTE [ USE OF NVRTC AND DRIVER API ]
856 const at::cuda::NVRTC& nvrtc() {
857   return at::globalContext().getNVRTC();
858 }
859 
860 // query codegen output arch and target
861 // TODO refactor so this function is usable both from jit and from aten
862 void codegenOutputQuery(
863     const cudaDeviceProp* const prop,
864     int& cuda_major,
865     int& cuda_minor,
866     int& nvrtc_major,
867     int& nvrtc_minor,
868     bool& compile_to_sass) {
869 #ifdef USE_ROCM
870   AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcVersion(&nvrtc_major, &nvrtc_minor));
871   cuda_major = prop->major;
872   cuda_minor = prop->minor;
873   compile_to_sass = false;
874 #else
875   AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcVersion(&nvrtc_major, &nvrtc_minor));
876   TORCH_CHECK(
877       nvrtc_major >= 6, "NVRTC versions less than 6 are not supported. Is: ", nvrtc_major);
878 
879   // Version supported by device
880   // Usually any lower version works too but is less efficient
881   using CUDAVersion = std::pair<int, int>;
882   const CUDAVersion nvrtc_version{nvrtc_major, nvrtc_minor};
883   const CUDAVersion dev_version{prop->major, prop->minor};
884   // Maximum version supported by the driver, cap dev_version to this
885   CUDAVersion max_dev_version;
886   if (nvrtc_major <= 7) { // 7 supports 2-5.x
887     max_dev_version = CUDAVersion(5, 0);
888   } else if (nvrtc_major <= 8) { // 8 supports 2-6.x
889     max_dev_version = CUDAVersion(6, 0);
890   } else if (nvrtc_major <= 9) { // 9 supports 3-7.2
891     max_dev_version = CUDAVersion(7, 2);
892   } else if (nvrtc_major <= 10) { // 10 supports 3-7.5
893     max_dev_version = CUDAVersion(7, 5);
894   } else if (nvrtc_version == CUDAVersion(11, 0)) { // 11.0 supports 3-8.0
895     max_dev_version = CUDAVersion(8, 0);
896   } else if (nvrtc_major == 11 && nvrtc_minor < 8) {
897     max_dev_version = CUDAVersion(8, 6);
898   } else {
899     // If the driver version is unknown (i.e. newer than this code)
900     // assume the driver supports this device
901     max_dev_version = dev_version;
902   }
903 
904   if (dev_version > max_dev_version) {
905     cuda_major = max_dev_version.first;
906     cuda_minor = max_dev_version.second;
907     // if we are clamping major/minor, sass is not compatible
908     compile_to_sass = false;
909   } else {
910     cuda_major = dev_version.first;
911     cuda_minor = dev_version.second;
912     compile_to_sass = true;
913   }
914 
915   #if defined(CUDA_VERSION) && CUDA_VERSION < 11010
916     // compile to sass is not allowed prior to CUDA 11.1
917     compile_to_sass = false;
918   #endif
919 #endif
920 }
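// Worked example (versions picked for illustration): with NVRTC reporting (11, 4) and an
// sm_90 device, max_dev_version is (8, 6), so the function returns cuda_major == 8,
// cuda_minor == 6 and compile_to_sass == false; an sm_80 device on the same toolkit keeps
// (8, 0) and compile_to_sass == true.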
921 
922 // TODO: another copy paste from jit, refactor so it's usable from both
923 // TODO: try making the CUcontext thread local to see if that improves performance - why is this slow?
924 void initializeCudaContext() {
925   // lazily construct the context if it does not exist yet;
926   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
927   CUcontext pctx = nullptr;
928   AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx));
929   if (!pctx) {
930     std::unique_lock<std::mutex> cudaFreeMutexLock(
931         *(c10::cuda::getFreeMutex()));
932     cudaFree(nullptr);
933   }
934 }
935 
936 std::string generate_code(
937     const KernelDescriptor &desc,
938     bool contiguous,
939     bool dynamic_casting,
940     BinaryFuncVariant scalar_pos,
941     bool vectorized,
942     int vec_size,
943     bool return_by_ref) {
944   c10::SmallVector<std::string> extra_args_typenames(desc.extra_args_types.size());
945   for (auto i : c10::irange(extra_args_typenames.size())) {
946     extra_args_typenames[i] = typeName(desc.extra_args_types[i]);
947   }
948 
949   return generate_code(
950       desc.nInputs,
951       desc.nOutputs,
952       desc.f,
953       desc.name,
954       typeName(desc.f_inputs_type),
955       typeName(toOpMathType(desc.f_inputs_type)),
956       typeName(desc.result_type),
957       contiguous,
958       dynamic_casting,
959       scalar_pos,
960       extra_args_typenames,
961       vectorized,
962       vec_size,
963       return_by_ref);
964 }
965 
966 //FIXME - these are defined in Loops.cuh, but including Loops.cuh here would lead to circular includes: Loops.cuh -> CUDALoops.cuh -> jit_utils.h -> Loops.cuh
967 #define THREAD_WORK_SIZE 4
968 constexpr int thread_work_size = THREAD_WORK_SIZE;
969 
970 std::string generate_code(
971     int nInputs,
972     int nOutputs,
973     const std::string& func_,
974     const std::string& name,
975     const std::string& f_inputs_type,
976     const std::string& compute_type,
977     const std::string& result_type,
978     bool contiguous,
979     bool dynamic_casting,
980     BinaryFuncVariant scalar_pos,
981     c10::SmallVector<std::string>& extra_args_typenames,
982     bool vectorized,
983     int vec_size,
984     bool return_by_ref) {
985   std::string func = func_;
986   at::jit::TemplateEnv env;
987 
988   env.s("index_type", "unsigned int");
989   env.s("nInputs", std::to_string(nInputs));
990   env.s("nOutputs", std::to_string(nOutputs));
991   env.s("scalar_type", f_inputs_type);
992   env.s("compute_type", compute_type);
993   env.s("functor", func);
994   env.s("name", name);
995   env.s("cmath_string", get_cmath_string());
996 
997   // Generate `extra_params` for the function signature
998   // and `extra_args` for the computation call when
999   // extra arguments capturing runtime state are passed
1000   // (see polygamma for an example).
1001   std::string extra_params = "";
1002   std::string extra_args = "";
1003   for (size_t i = 0; i < extra_args_typenames.size(); i++) {
1004     auto type = std::string(extra_args_typenames[i]);
1005     auto name = "extra_arg_" + std::to_string(i);
1006     extra_params += "," + type + " " + name;
1007     extra_args += ", " + name;
1008   }
1009   env.s("extra_params", extra_params);
1010   env.s("extra_args", extra_args);
1011 
1012   std::stringstream declare_load_arrays;
1013   for (int i = 0; i < nInputs; i++) {
1014     // TODO: these arrays are potentially of different types; use function
1015     // traits to determine the types
1016     declare_load_arrays << f_inputs_type << " arg" << std::to_string(i)
1017                         << "[" << std::to_string(thread_work_size) << "];\n";
1018   }
1019   env.s("declare_load_arrays", declare_load_arrays.str());
1020 
1021   std::stringstream declare_store_arrays;
1022   for (int i = 0; i < nOutputs; i++) {
1023     declare_store_arrays << result_type << " out" << std::to_string(i)
1024                         << "[" << std::to_string(thread_work_size) << "];\n";
1025   }
1026   env.s("declare_store_arrays", declare_store_arrays.str());
1027 
1028   std::stringstream functor_args;
1029   if (scalar_pos == BinaryFuncVariant::NoScalar) {
1030     for (int i = 0; i < nInputs - 1; i++) {
1031       functor_args << "arg" << std::to_string(i) << "[j], ";
1032     }
1033     functor_args << "arg" << std::to_string(nInputs - 1) << "[j]";
1034   } else if (scalar_pos == BinaryFuncVariant::LhsScalar) {
1035     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(nInputs == 1);
1036     functor_args << "scalar_val, arg0[j]";
1037   } else { //RhsScalar
1038     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(nInputs == 1);
1039     functor_args << "arg0[j], scalar_val";
1040   }
1041   env.s("args", functor_args.str());
1042 
1043   std::string call_functor_template;
1044   if (return_by_ref) {  // return one or more outputs by reference
1045     bool need_temp_out = (compute_type != result_type);
1046     std::stringstream functor_outs;
1047     if (need_temp_out) {
1048       for (int i = 0; i < nOutputs - 1; i++) {
1049         functor_outs << "temp_out" << std::to_string(i) << ", ";
1050       }
1051       functor_outs << "temp_out" << std::to_string(nOutputs - 1);
1052     } else {
1053       for (int i = 0; i < nOutputs - 1; i++) {
1054         functor_outs << "out" << std::to_string(i) << "[j], ";
1055       }
1056       functor_outs << "out" << std::to_string(nOutputs - 1) << "[j]";
1057     }
1058     env.s("functor_outs", functor_outs.str());
1059 
1060     if (need_temp_out) {
1061       call_functor_template += "${compute_type} ${functor_outs};\n";
1062     }
1063 
1064     call_functor_template += "${name}<${compute_type}>(${args} ${extra_args}, ${functor_outs});\n";
1065 
1066     if (need_temp_out) {
1067       for (int i = 0; i < nOutputs; i++) {
1068         auto i_string = std::to_string(i);
1069         call_functor_template += "out" +i_string + "[j] = temp_out" + i_string + ";\n";
1070       }
1071     }
1072 
1073   } else {  // return by value for single output functor
1074     call_functor_template = "out0[j] = ${name}<${compute_type}>(${args} ${extra_args});";
1075   }
1076   env.s("call_functor", at::jit::CodeTemplate(call_functor_template).format(env));
1077 
1078   if (f_inputs_type == "at::Half" || result_type == "at::Half" ||
1079       f_inputs_type == "std::complex<at::Half>" ||
1080       result_type == "std::complex<at::Half>" || dynamic_casting) {
1081     // complex<Half> depends on complex<T> and Half dtypes.
1082     env.s("half_string", jiterator_half_support_literal);
1083   } else {
1084     env.s("half_string", "");
1085   }
1086   if (f_inputs_type == "at::BFloat16" || result_type == "at::BFloat16" || dynamic_casting) {
1087     env.s("bfloat16_string", jiterator_bfloat16_support_literal);
1088   } else {
1089     env.s("bfloat16_string", "");
1090   }
1091   // the definition of complex math functions is only needed when the compute type is complex
1092   // but the definition of std::complex is needed for dynamic casting even if the compute type is not complex
1093   if (f_inputs_type == "std::complex<float>" || result_type == "std::complex<float>" ||
1094       f_inputs_type == "std::complex<double>" || result_type == "std::complex<double>" ||
1095       f_inputs_type == "std::complex<at::Half>" || result_type == "std::complex<at::Half>") {
1096     // complex<Half> depends on complex<T> and Half dtypes.
1097     env.s("traits_string", get_traits_string_but_hiprtc_safe());
1098     env.s("complex_body_string", get_complex_body_string());
1099     env.s("complex_math_string", get_complex_math_string());
1100 #ifdef USE_ROCM
1101     // unhipify math functions, but only if std::complex is used.
1102     func = unhipify_math_functions(func);
1103     env.s("functor", func);
1104 #endif
1105   } else if (dynamic_casting) {
1106     env.s("traits_string", get_traits_string_but_hiprtc_safe());
1107     env.s("complex_body_string", get_complex_body_string());
1108     env.s("complex_math_string", "");
1109   } else {
1110     env.s("traits_string", "");
1111     env.s("complex_body_string", "");
1112     env.s("complex_math_string", "");
1113   }
1114   if (f_inputs_type == "std::complex<at::Half>" ||
1115       result_type == "std::complex<at::Half>" || dynamic_casting) {
1116     // dynamic_casting requires the definition of all types,
1117     // including complex<at::Half>.
1118     // See the definitions of `StoreWithCast` and `LoadWithCast`.
1119     env.s("complex_half_body_string", get_complex_half_body_string());
1120   } else {
1121     env.s("complex_half_body_string", "");
1122   }
1123 
1124   env.s("load_support", load_support_literal);
1125 
1126   if (!vectorized) {
1127     if (!dynamic_casting) {
1128       env.s("loader", "LoadWithoutCast");
1129       env.s("storer", "StoreWithoutCast");
1130       env.s("dynamic_casting_string", no_dynamic_cast_support_literal);
1131     } else {
1132       env.s("loader", std::string("LoadWithCast<" + std::to_string(nInputs) + ">"));
1133       env.s("storer", std::string("StoreWithCast<" + std::to_string(nOutputs) + ">"));
1134       env.s("dynamic_casting_string", dynamic_cast_support_literal);
1135     }
1136 
1137     if (contiguous) {
1138       env.s("offset_calculator", "TrivialOffsetCalculator");
1139     } else {
1140       env.s("offset_calculator", "OffsetCalculator");
1141     }
1142 
1143     std::stringstream load_inputs;
1144     for (int i = 0; i < nInputs; i++) {
1145       auto i_string = std::to_string(i);
1146       load_inputs << "arg" << i_string << "[j] = l.load<" << f_inputs_type
1147                   << ">(data[" << std::to_string(i + nOutputs)
1148                   << "], input_offsets[" << i_string << "], " << i_string
1149                   << ");\n";
1150     }
1151     env.s("load_inputs", load_inputs.str());
1152 
1153     std::stringstream store_outputs;
1154     for (int i = 0; i < nOutputs; i++) {
1155       auto i_string = std::to_string(i);
1156       store_outputs << "s.store<" << result_type
1157                     << ">(out" << i_string << "[j], data[" << i_string
1158                     << "], output_offsets[" << i_string << "], " << i_string
1159                     << ");\n";
1160     }
1161     env.s("store_outputs", store_outputs.str());
1162 
1163     static auto cuda_template = at::jit::CodeTemplate(
1164       jit_preamble + jit_common_types + offset_calc_template + jit_code_template + jit_epilogue);
1165     const auto code = cuda_template.format(env);
1166     return code;
1167   }
1168 
1169   // vectorized case
1170   env.s("vec_size", std::to_string(vec_size));
1171   env.s("result_type", result_type);
1172 
1173   std::stringstream vector_inputs;
1174   for (const auto i : c10::irange(nInputs)){
1175     auto i_string = std::to_string(i);
1176     vector_inputs << "auto * input" << i_string <<
1177         " = reinterpret_cast<const scalar_t*>(data[" << i_string << "+" << nOutputs << "])" <<
1178         " + block_work_size * idx;\n";
1179   }
1180   env.s("vector_inputs", vector_inputs.str());
1181 
1182   std::stringstream vector_outputs;
1183   for (const auto i : c10::irange(nOutputs)){
1184     auto i_string = std::to_string(i);
1185     vector_outputs << "vec_t_output* to_" << i_string <<
1186     " = reinterpret_cast<vec_t_output*>(data[" << i_string << "])" <<
1187     " + block_work_size / vec_size * idx;\n";
1188   }
1189   env.s("vector_outputs", vector_outputs.str());
1190 
1191   std::stringstream load_vectorized_inputs;
1192   for (const auto i : c10::irange(nInputs)) {
1193     auto i_string = std::to_string(i);
1194     load_vectorized_inputs << "const auto vec" << i_string << " = load_vector<vec_size>("
1195                            << "input" << i_string << ", thread_idx);\n";
1196     load_vectorized_inputs << "#pragma unroll\n";
1197     load_vectorized_inputs << "for (int j=0; j < vec_size; j++){\n";
1198     load_vectorized_inputs << "  arg" << i_string << "[vec_size * i + j] = vec" << i_string << ".val[j];\n";
1199     load_vectorized_inputs << "}\n";
1200   }
1201   env.s("load_vectorized_inputs", load_vectorized_inputs.str());
1202 
1203   std::stringstream store_vectorized_outputs;
1204   for (const auto i : c10::irange(nOutputs)) {
1205     auto i_string = std::to_string(i);
1206     store_vectorized_outputs << "#pragma unroll\n";
1207     store_vectorized_outputs << "for (int j=0; j<vec_size; j++){\n";
1208     store_vectorized_outputs <<   "v.val[j] = out" << i_string << "[vec_size * i + j];\n";
1209     store_vectorized_outputs << "}\n";
1210     store_vectorized_outputs << "to_"<< i_string << "[thread_idx] = v;\n";
1211   }
1212   env.s("store_vectorized_outputs", store_vectorized_outputs.str());
1213 
1214   std::stringstream load_unrolled_inputs;
1215   for (const auto i: c10::irange(nInputs)){
1216     auto i_string = std::to_string(i);
1217     load_unrolled_inputs << "arg" << i_string << "[j] = load<" << f_inputs_type
1218       << ">(data[" << std::to_string(i + nOutputs) << "], linear_idx);\n";
1219   }
1220   env.s("load_unrolled_inputs", load_unrolled_inputs.str());
1221 
1222   std::stringstream store_unrolled_outputs;
1223   for (const auto i : c10::irange(nOutputs)) {
1224     auto i_string = std::to_string(i);
1225     store_unrolled_outputs << "store<" << result_type << ">(out" << i_string
1226       << "[j], data[" << i_string << "], linear_idx);\n";
1227   }
1228   env.s("store_unrolled_outputs", store_unrolled_outputs.str());
1229 
1230   static auto cuda_template = at::jit::CodeTemplate(
1231     jit_preamble + jit_common_types + jit_vectorized_code_template + jit_epilogue);
1232   const auto code = cuda_template.format(env);
1233   return code;
1234 }
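// Sketch of the substitution for a hypothetical binary functor named "my_add"
// (nInputs == 2, nOutputs == 1, float compute type, no extra args, not vectorized,
// contiguous, no dynamic casting):
//   loader            -> LoadWithoutCast
//   storer            -> StoreWithoutCast
//   offset_calculator -> TrivialOffsetCalculator
//   call_functor      -> out0[j] = my_add<float>(arg0[j], arg1[j] );
// and the emitted entry point is the extern "C" kernel my_add_kernel.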
1235 
1236 // Creates directories recursively
1237 bool _r_mkdir(const std::string& dir) {
1238   // Check if current dir exists
1239   const char* p_dir = dir.c_str();
1240   const bool dir_exists = (access(p_dir, F_OK) == 0);
1241   if (dir_exists) {
1242     return true;
1243   }
1244 
1245   // Try to create current directory
1246 #ifdef _WIN32
1247   int ret = _mkdir(dir.c_str());
1248 #else
1249   int ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO);
1250 #endif
1251   // Success
1252   if (ret == 0) {
1253     return true;
1254   }
1255 
1256   // Find folder separator and check if we are at the top
1257   auto  pos = dir.find_last_of("/\\");
1258   if (pos == std::string::npos) {
1259     return false;
1260   }
1261 
1262   // Try to create parent directory
1263   if (!(_r_mkdir(dir.substr(0, pos)))) {
1264     return false;
1265   }
1266 
1267   // Try to create complete path again
1268 #ifdef _WIN32
1269   ret = _mkdir(dir.c_str());
1270 #else
1271   ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO);
1272 #endif
1273   return ret == 0;
1274 }
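// Behaviour sketch (hypothetical path): _r_mkdir("/tmp/torch/kernels") first tries to
// create the full path directly; on failure it recurses on "/tmp/torch" and then "/tmp",
// and finally retries the full path, returning true when the directory exists or could
// be created.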
1275 
1276 // Creates directories recursively assuming that base exists
1277 bool r_mkdir_with_base(std::string& base, std::string& dir){
1278   const char* p_base = base.c_str();
1279   const bool base_exists = (access(p_base, F_OK) == 0);
1280   if (!base_exists) {
1281     return false;
1282   }
1283 
1284   // remove trailing '/' or '\\'
1285   if ((base[base.size()-1]=='/') || base[base.size()-1]=='\\') {
1286     base.pop_back();
1287   }
1288   if ((dir[dir.size()-1]=='/') || dir[dir.size()-1]=='\\') {
1289     dir.pop_back();
1290   }
1291 
1292   return _r_mkdir(base+dir);
1293 
1294 }
1295 
1296 std::string load_code_template(const std::string& path) {
1297   std::ifstream ifs{path};
1298   std::string s{
1299     std::istreambuf_iterator<char>(ifs),
1300     std::istreambuf_iterator<char>()};
1301   return s;
1302 }
1303 
1304 std::string generate_reduction_code(
1305     const KernelDescriptor &desc,
1306     int vt0,
1307     bool contiguous,
1308     bool vectorized,
1309     int vec_size,
1310     int max_threads_codegen) {
1311   TORCH_INTERNAL_ASSERT(desc.nInputs == 1);
1312   TORCH_INTERNAL_ASSERT(desc.extra_args_types.size() == 0);
1313 
1314   return generate_reduction_code(
1315       desc.nOutputs,
1316       desc.f,
1317       desc.name,
1318       vt0,
1319       typeName(desc.f_inputs_type),
1320       typeName(toOpMathType(desc.f_inputs_type)),
1321       typeName(desc.result_type),
1322       contiguous,
1323       vectorized,
1324       vec_size,
1325       max_threads_codegen
1326     );
1327 }
1328 
1329 std::string generate_reduction_code(
1330     int nOutputs,
1331     const std::string& func_,
1332     const std::string& name,
1333     const int vt0,
1334     const std::string& f_inputs_type,
1335     const std::string& reduction_accum_type,
1336     const std::string& result_type,
1337     bool contiguous,
1338     bool vectorized,
1339     int vec_size,
1340     int max_threads_codegen) {
1341       std::string func = func_;
1342       at::jit::TemplateEnv env;
1343       env.s("index_type", "unsigned int");
1344       env.s("scalar_type", f_inputs_type);
1345       env.s("result_type", result_type);
1346       env.s("reduction_accum_type", reduction_accum_type);
1347       env.s("vt0", std::to_string(vt0));
1348       env.s("name", name);
1349       env.s("max_threads_lb", std::to_string(max_threads_codegen));
1350       // reductions don't support dynamic casting, so the only way to get nonstandard types
1351       // is through input
1352       if (f_inputs_type == "at::Half" || f_inputs_type == "std::complex<at::Half>") {
1353         // complex<Half> depends on complex<T> and Half dtypes.
1354         env.s("half_string", jiterator_half_support_literal);
1355       } else {
1356         env.s("half_string", "");
1357       }
1358       if (f_inputs_type == "at::BFloat16") {
1359         env.s("bfloat16_string", jiterator_bfloat16_support_literal);
1360       } else {
1361         env.s("bfloat16_string", "");
1362       }
1363       if (f_inputs_type == "std::complex<float>" ||
1364           f_inputs_type == "std::complex<double>" ||
1365           f_inputs_type == "std::complex<at::Half>" ) {
1366         // complex<Half> depends on complex<T> and Half dtypes.
1367         env.s("traits_string", get_traits_string_but_hiprtc_safe());
1368         env.s("complex_body_string", get_complex_body_string());
1369         env.s("complex_math_string", get_complex_math_string());
1370         env.s("complex", std::to_string(1));
1371 #ifdef USE_ROCM
1372         // unhipify math functions, but only if std::complex is used.
1373         func = unhipify_math_functions(func);
1374 #endif
1375       } else {
1376         env.s("traits_string", "");
1377         env.s("complex_body_string", "");
1378         env.s("complex_math_string", "");
1379         env.s("complex", std::to_string(0));
1380       }
1381       if (f_inputs_type == "std::complex<at::Half>") {
1382         env.s("complex_half_body_string", get_complex_half_body_string());
1383       } else {
1384         env.s("complex_half_body_string", "");
1385       }
1386       env.s("cmath_string", get_cmath_string());
1387       env.s("functor", func);
1388       env.s("output_vec_size", std::to_string(vec_size));
1389       static auto cuda_template = at::jit::CodeTemplate(
1390         jit_preamble + jit_common_types + offset_calc_template + get_reduction_template() + jit_epilogue);
1391       const auto code = cuda_template.format(env);
1392       return code;
1393 }
1394 
1395 // Acquires (possibly creating) the kernel cache directory
1396 std::optional<std::string> get_cache_dir() {
1397   // If the environment variable USE_PYTORCH_KERNEL_CACHE is set to "0" then no persistent cache is used
1398   const char* uptkc = std::getenv("USE_PYTORCH_KERNEL_CACHE");
1399   const bool use_kernel_cache = (uptkc == nullptr) ? true : std::strcmp(uptkc, "0");
1400 
1401   if (!use_kernel_cache) {
1402     return {};
1403   }
1404 
1405   // Cache path comes from PYTORCH_KERNEL_CACHE_PATH, then TEMP (Windows) or XDG_CACHE_HOME (Linux), then HOME environment variables
1406   std::string cache_dir;
1407   char* ptkcp = std::getenv("PYTORCH_KERNEL_CACHE_PATH");
1408   // Create kernel_cache_dir if needed as we do not want to create the base directory passed by the user
1409   std::string kernels_cache_dir = "";
1410   if (ptkcp != nullptr) {
1411     cache_dir = std::string(ptkcp);
1412   } else {
1413 #ifdef _WIN32
1414     ptkcp = std::getenv("TEMP");
1415 #else
1416     // USES XDG_CACHE_HOME if it's set
1417     ptkcp = std::getenv("XDG_CACHE_HOME");
1418 #endif
1419     if (ptkcp != nullptr) {
1420       kernels_cache_dir = "/torch/kernels";
1421       cache_dir = std::string(ptkcp) + kernels_cache_dir;
1422     } else {
1423       // Falls back to HOME/.cache
1424       ptkcp = std::getenv("HOME");
1425       if (ptkcp == nullptr) {
1426         TORCH_WARN_ONCE("No PYTORCH_KERNEL_CACHE_PATH or HOME environment variable set!",
1427                         " This disables kernel caching.");
1428         return {};
1429       } else {
1430         kernels_cache_dir = "/.cache/torch/kernels";
1431         cache_dir = std::string(ptkcp) + kernels_cache_dir;
1432       }
1433     }
1434   }
1435 
1436   // Creates the cache directory if it does not exist
1437   const char* p_cache_dir = cache_dir.c_str();
1438   const bool cache_dir_exists = (access(p_cache_dir, F_OK) == 0);
1439   if (!cache_dir_exists) {
1440     std::string s_ptkcp = std::string(ptkcp);
1441     if (!r_mkdir_with_base(s_ptkcp, kernels_cache_dir)) {
1442       TORCH_WARN_ONCE("Specified kernel cache directory could not be created! This disables kernel caching.",
1443                       " Specified directory is ", cache_dir, ".",
1444                       " This warning will appear only once per process.");
1445       return {};
1446     }
1447   }
1448 
1449   // Checks that the cache directory is readable and writable
1450   const bool cache_dir_readable = (access(p_cache_dir, R_OK) == 0);
1451   if (!cache_dir_readable) {
1452     TORCH_WARN_ONCE("Specified kernel cache directory is not readable! This disables kernel caching.",
1453                     " Specified directory is ", cache_dir, ".",
1454                     " This warning will appear only once per process.");
1455     return {};
1456   }
1457 
1458   const bool cache_dir_writable = (access(p_cache_dir, W_OK) == 0);
1459   if (!cache_dir_writable) {
1460     TORCH_WARN_ONCE("Specified kernel cache directory is not writable! This disables kernel caching.",
1461                     " Specified directory is ", cache_dir, ".",
1462                     " This warning will appear only once per process.");
1463     return {};
1464   }
1465 
1466   return cache_dir;
1467 }
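// Resolution-order sketch (paths are illustrative): with no PYTORCH_KERNEL_CACHE_PATH
// set, XDG_CACHE_HOME=/home/user/.cache resolves to /home/user/.cache/torch/kernels;
// with neither set, HOME=/home/user resolves to /home/user/.cache/torch/kernels;
// USE_PYTORCH_KERNEL_CACHE=0 disables the persistent cache entirely.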
1468 
1469 // Compiles the kernel, or acquires it from the cache if caching is enabled
1470 NvrtcFunction jit_pwise_function(
1471     const std::string& code,
1472     const std::string& kernel_name) {
1473   initializeCudaContext();
1474   // Acquires CUDA and nvrtc versions and whether we're compiling to ptx or SASS
1475   const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
1476   int cuda_major = 0, cuda_minor = 0, nvrtc_major = 0, nvrtc_minor = 0;
1477   bool compile_to_sass = false;
1478   at::cuda::jit::codegenOutputQuery(
1479     prop, cuda_major, cuda_minor, nvrtc_major, nvrtc_minor, compile_to_sass);
1480 
1481   // Objects used whether loading from the cache or jit compiling
1482   const auto& nvrtc = at::globalContext().getNVRTC();
1483   NvrtcFunction compiled_kernel_;
1484   std::string name = kernel_name + "_kernel";
1485 
1486   static const std::optional<std::string> cache_dir = get_cache_dir();
1487 
1488   std::string file_path;
1489   if (cache_dir.has_value()) {
1490     // Attempts to read from the cache.
1491     // Cubin name is <kernel name>_arch<major>.<minor>_nvrtc<major>.<minor>_<ptx or sass>_<program length>_<string hash>
1492     // Note that the SHA1 hash used in the file name is NOT the SHA1 hash of the file's contents,
1493     //   because we hash the CUDA source code but save the compiled PTX or SASS
1494 
1495     // Acquires SHA1 hash
1496     c10::sha1 sha1_hash{code};
1497     const auto hash_code = sha1_hash.str();
1498 
1499     // Constructs file path by appending constructed cubin name to cache path
1500     std::stringstream ss;
1501     ss << *cache_dir << "/";
1502     ss << kernel_name;
1503 #ifdef USE_ROCM
1504     ss << "_arch" << prop->gcnArchName;
1505 #else
1506     ss << "_arch" << cuda_major << "." << cuda_minor;
1507 #endif
1508     ss << "_nvrtc" << nvrtc_major << "." << nvrtc_minor;
1509     ss << (compile_to_sass ? "_sass" : "_ptx");
1510     ss << "_" << code.length();
1511     ss << "_" << hash_code;
1512     file_path = ss.str();
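    // For illustration only, a resulting file path might look like
    //   <cache_dir>/my_kernel_arch8.0_nvrtc11.8_sass_12345_<sha1 of the CUDA source>
    // where "my_kernel", the versions, and the source length are hypothetical values.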
1513 
1514     std::ifstream readin{file_path, std::ios::in | std::ifstream::binary};
1515     if (readin.fail()) {
1516       // NOTE: this does not warn because the file might not exist
1517       // TODO: consider explicitly checking for the file's existence so an informative
1518       //   warning can be emitted when the read fails for some other reason
1519       readin.close();
1520     } else {
1521       // TODO: try passing the "mapped" file directly to the cuModuleLoadData call instead of using an intermediate buffer
1522       std::vector<char> buffer(std::istreambuf_iterator<char>(readin), {});
1523       AT_CUDA_DRIVER_CHECK(nvrtc.cuModuleLoadData(&(compiled_kernel_.module), buffer.data()));
1524       AT_CUDA_DRIVER_CHECK(
1525         nvrtc.cuModuleGetFunction(&(compiled_kernel_.function), compiled_kernel_.module, name.c_str()));
1526       readin.close();
1527       return compiled_kernel_;
1528     }
1529   }
1530 
1531   // Just-in-time compiles the program
1532 
1533   // Creates the NVRTC program
1534   nvrtcProgram program;
1535   AT_CUDA_NVRTC_CHECK(nvrtc.nvrtcCreateProgram(
1536       &program, code.c_str(), nullptr, 0, nullptr, nullptr));
1537 
1538 #ifdef USE_ROCM
1539   std::vector<const char*> args = {"--std=c++17"};
1540 #else
1541   // Constructs nvrtc build arguments.
1542   // CUDA 11.1 allows compiling directly to SASS (sm_) instead of PTX (compute_),
1543   // which gives better backward compatibility with older drivers
1544   // (an older driver does not necessarily recognize PTX emitted by a newer
1545   // toolkit).
1546   // Meanwhile, for forward compatibility (a future device with
1547   // `unsupported_arch==True`), SASS is not necessarily compatible,
1548   // so we fall back to PTX instead.
1549   const std::string compute = std::string("--gpu-architecture=") +
1550       (compile_to_sass ? "sm_" : "compute_") + std::to_string(cuda_major) +
1551       std::to_string(cuda_minor);
1552   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
1553   std::vector<const char*> args = {
1554       "--std=c++17", compute.c_str(), "-default-device"};
1555 #endif
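  // For example, on a hypothetical sm_80 device with a new enough toolkit the flag built
  // above would be "--gpu-architecture=sm_80", or "--gpu-architecture=compute_80" when
  // falling back to PTX.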
1556 
1557   #ifndef NDEBUG
1558     // Add line info to generated kernels
1559     args.push_back("-lineinfo");
1560   #else
1561     // Avoid excessive register usage from assertion
1562     args.push_back("-DNDEBUG");
1563   #endif
1564 
1565   const auto compilation_result =
1566       nvrtc.nvrtcCompileProgram(program, args.size(), args.data());
1567 
1568   // Throws an error on compilation failure
1569   if (compilation_result != NVRTC_SUCCESS) {
1570     size_t logsize;
1571     AT_CUDA_NVRTC_CHECK(nvrtc.nvrtcGetProgramLogSize(program, &logsize));
1572     std::string log(logsize, '\0');
1573     AT_CUDA_NVRTC_CHECK(nvrtc.nvrtcGetProgramLog(program, &log[0]));
1574     throw std::runtime_error(code + log);
1575   }
1576 
1577   size_t ptx_size = 0;
1578   std::vector<char> ptx;
1579   #if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
1580     // compile_to_sass determines whether we are generating SASS or PTX, hence
1581     // the different API.
1582     const auto getSize = compile_to_sass
1583         ? at::globalContext().getNVRTC().nvrtcGetCUBINSize
1584         : at::globalContext().getNVRTC().nvrtcGetPTXSize;
1585     const auto getFunc = compile_to_sass
1586         ? at::globalContext().getNVRTC().nvrtcGetCUBIN
1587         : at::globalContext().getNVRTC().nvrtcGetPTX;
1588   #else
1589     const auto getSize = at::globalContext().getNVRTC().nvrtcGetPTXSize;
1590     const auto getFunc = at::globalContext().getNVRTC().nvrtcGetPTX;
1591   #endif
1592 
1593   AT_CUDA_NVRTC_CHECK(getSize(program, &ptx_size));
1594   ptx.resize(ptx_size);
1595   AT_CUDA_NVRTC_CHECK(getFunc(program, ptx.data()));
1596 
1597   AT_CUDA_DRIVER_CHECK(nvrtc.cuModuleLoadData(&(compiled_kernel_.module), ptx.data()));
1598 
1599   AT_CUDA_DRIVER_CHECK(
1600       nvrtc.cuModuleGetFunction(&(compiled_kernel_.function), compiled_kernel_.module, name.c_str()));
1601   // TODO: use guards to avoid leaking
1602   AT_CUDA_NVRTC_CHECK(nvrtc.nvrtcDestroyProgram(&program));
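  // A sketch of the guard mentioned in the TODO above (hypothetical; not used in this file):
  // declaring something like
  //   auto program_guard = c10::make_scope_exit(
  //       [&]() { nvrtc.nvrtcDestroyProgram(&program); });
  // right after nvrtcCreateProgram would destroy the program even when compilation throws.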
1603 
1604   if (cache_dir.has_value()) {
1605     // Writes the program to the cache if caching is enabled
1606     // NOTE: Actually writes to a per-process temporary file to avoid multi-process contention.
1607     //   The temporary file is then renamed to the actual file.
1608     //   If the actual file already exists then the rename may fail or replace the actual file;
1609     //     the behavior is implementation-specific.
1610     //   Files replaced through this process should remain extant if they are being read because
1611     //     of UNIX filesystem properties, but this behavior is unverified and may require
1612     //     additional review in the future.
1613     // TODO: In C++17 we should be able to use the filesystem header.
1614     const auto pid = getpid();
1615     std::stringstream tmp_file_path_ss;
1616     tmp_file_path_ss << file_path << "_tmp_" << pid;
1617     const std::string tmp_file_path = tmp_file_path_ss.str();
1618     std::ofstream cubin(tmp_file_path, std::ios::out | std::ofstream::binary);
1619     if (cubin.fail()) {
1620       TORCH_WARN_ONCE("Failed to write temporary kernel cache file!",
1621                       " File path was ", tmp_file_path, ".",
1622                       " This warning will only appear once per process.");
1623     } else {
1624       std::copy(ptx.begin(), ptx.end(), std::ostreambuf_iterator<char>(cubin));
1625       cubin.close();  // flush and close before renaming so the cached file's contents are complete
1626       if (std::rename(tmp_file_path.c_str(), file_path.c_str()) != 0) {
1627         // Removes the temporary file if the rename failed
1628         std::remove(tmp_file_path.c_str());
1629       }
1630     }
1631   }
1632 
1633   return compiled_kernel_;
1634 }
1635 
1636 // TODO: may need/want to initialize CUDA context here (refactor into nvrtc call)
1637 void launch_jitted_pwise_function(
1638     NvrtcFunction function,
1639     void* args[],
1640     const dim3 nBlocks,
1641     const dim3 kBlockSize,
1642     const int smem) {
1643   initializeCudaContext();
1644   const auto& nvrtc = at::globalContext().getNVRTC();
1645   // Launches kernel on current stream
1646   auto stream = at::cuda::getCurrentCUDAStream();
1647   AT_CUDA_DRIVER_CHECK(nvrtc.cuLaunchKernel(
1648     function.function,
1649     nBlocks.x,
1650     nBlocks.y,
1651     nBlocks.z,
1652     kBlockSize.x,
1653     kBlockSize.y,
1654     kBlockSize.z,
1655     smem,
1656     stream,
1657     args,
1658     nullptr));
1659 }
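// A minimal usage sketch of the two functions above (illustrative only; the kernel name,
// CUDA source, pointers, and launch configuration below are hypothetical):
//
//   std::string code = /* generated CUDA source defining "my_op_kernel" */;
//   NvrtcFunction fn = jit_pwise_function(code, "my_op");
//   void* args[] = {&out_ptr, &in_ptr, &numel};
//   launch_jitted_pwise_function(fn, args, dim3(num_blocks), dim3(num_threads), 0);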
1660 
1661 } // at::cuda::jit
1662