xref: /aosp_15_r20/external/pytorch/aten/src/ATen/jiterator_macros.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/macros/Macros.h>
3 #include <string>
4 
// JITERATOR_HOST_DEVICE expands to C10_HOST_DEVICE, except when compiling
// with a CUDA compiler under MSVC, where it expands to nothing.
//
// NVRTC on Windows errors if a __host__ __device__ attribute is present on
// the kernel:
//   error: attribute "__host__" does not apply here
//   error: attribute "__device__" does not apply here
//
// The two cases are selected with #if/#else so the macro is defined exactly
// once. Defining it and then redefining it with a different replacement list
// is ill-formed ([cpp.replace]) and triggers MSVC warning C4005.
#if defined(_MSC_VER) && defined(__CUDACC__)
#define JITERATOR_HOST_DEVICE
#else
#define JITERATOR_HOST_DEVICE C10_HOST_DEVICE
#endif
13 
// The jiterator_also_stringify_as macro always emits the given code (on every
// backend). It also generates a code string for `jiterator`, but only when
// compiling with nvcc or hipcc (i.e. when `__CUDACC__` or `__HIPCC__` is
// defined).
// Usage :
//      jiterator_also_stringify_as(
//          jiterator_code(template <typename T> T identity(T x) { return x; }),
//          identity_string);
// This will define the template `identity` as present in code. When compiled
// with nvcc or hipcc, it will also define `const std::string identity_string`
// holding that code as a string.
23 
// jiterator_code(...) re-emits its arguments unchanged, commas included.
// Wrapping kernel source in it lets that source contain top-level commas
// (e.g. `template <typename T, typename U>`). Without the wrapper, the
// preprocessor would split the source at each such comma and pass too many
// arguments to the enclosing macro.
#define jiterator_code(...) \
  __VA_ARGS__
#if !defined(__CUDACC__) && !defined(__HIPCC__)
// Neither nvcc nor hipcc is compiling this translation unit: only the
// definition itself is emitted; no source string is produced.
#define jiterator_also_stringify_as(code, str_name) code
#else
// nvcc or hipcc is compiling this translation unit: emit the definition and
// also capture its source text in `const std::string str_name`.
//
// stringify_code adds a level of indirection so `code` is macro-expanded
// before it is stringified. Applying `#code` directly would stringify the
// unexpanded `jiterator_code(...)` wrapper. It is variadic because the
// expanded code may contain commas.
#define stringify_code(...) #__VA_ARGS__
#define jiterator_also_stringify_as(code, str_name) \
  code /* define the function */                    \
  const std::string str_name(stringify_code(code));
#endif
39