#pragma once

#include <cuda.h>
#include <c10/util/complex.h>
#include <c10/util/Half.h>

// Returns the mask of threads in the warp that are currently executing.
__device__ __forceinline__ unsigned int ACTIVE_MASK()
{
#if !defined(USE_ROCM)
    return __activemask();
#else
    // The mask is ignored on ROCm, so a full mask is fine here.
    return 0xffffffff;
#endif
}

// Synchronizes the threads in `mask`. A no-op on ROCm, where wavefronts
// execute in lockstep and need no explicit warp-level sync.
__device__ __forceinline__ void WARP_SYNC(unsigned mask = 0xffffffff) {
#if !defined(USE_ROCM)
  return __syncwarp(mask);
#endif
}

#if defined(USE_ROCM)
// ROCm wavefronts are 64 lanes wide, so the ballot result needs 64 bits.
__device__ __forceinline__ unsigned long long int WARP_BALLOT(int predicate)
{
    return __ballot(predicate);
}
#else
__device__ __forceinline__ unsigned int WARP_BALLOT(int predicate, unsigned int mask = 0xffffffff)
{
    return __ballot_sync(mask, predicate);
}
#endif

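// Usage sketch (illustrative, not part of the original header): counting the
// lanes whose predicate is true via WARP_BALLOT. `countTrueLanes` is a
// hypothetical helper, not a PyTorch API; note the ballot result is 64 bits
// wide on ROCm, hence __popcll there.
__device__ __forceinline__ int countTrueLanes(int predicate) {
#if !defined(USE_ROCM)
  // Expands to __popc(__ballot_sync(0xffffffff, predicate)) on CUDA.
  return __popc(WARP_BALLOT(predicate));
#else
  return __popcll(WARP_BALLOT(predicate));
#endif
}
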
// Warp shuffle wrappers: on CUDA they use the *_sync variants with an explicit
// participation mask; on ROCm they fall back to the legacy shuffles, which
// ignore the mask.
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if !defined(USE_ROCM)
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
#endif
}

template <typename T>
__device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if !defined(USE_ROCM)
    return __shfl_sync(mask, value, srcLane, width);
#else
    return __shfl(value, srcLane, width);
#endif
}

template <typename T>
__device__ __forceinline__ T WARP_SHFL_UP(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if !defined(USE_ROCM)
    return __shfl_up_sync(mask, value, delta, width);
#else
    return __shfl_up(value, delta, width);
#endif
}

template <typename T>
__device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if !defined(USE_ROCM)
    return __shfl_down_sync(mask, value, delta, width);
#else
    return __shfl_down(value, delta, width);
#endif
}

#if defined(USE_ROCM)
template<>
__device__ __forceinline__ int64_t WARP_SHFL_DOWN<int64_t>(int64_t value, unsigned int delta, int width, unsigned int mask)
{
  // HIP's __shfl_down doesn't support int64_t, so shuffle the value as two
  // 32-bit halves. Trick from https://devblogs.nvidia.com/faster-parallel-reductions-kepler/
  int2 a = *reinterpret_cast<int2*>(&value);
  a.x = __shfl_down(a.x, delta);
  a.y = __shfl_down(a.y, delta);
  return *reinterpret_cast<int64_t*>(&a);
}
#endif

template<>
__device__ __forceinline__ c10::Half WARP_SHFL_DOWN<c10::Half>(c10::Half value, unsigned int delta, int width, unsigned int mask)
{
  // Shuffle the raw 16-bit payload and reassemble the Half from bits.
  return c10::Half(WARP_SHFL_DOWN<unsigned short>(value.x, delta, width, mask), c10::Half::from_bits_t{});
}

template <typename T>
__device__ __forceinline__ c10::complex<T> WARP_SHFL_DOWN(c10::complex<T> value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
{
  // Shuffle the real and imaginary parts independently.
#if !defined(USE_ROCM)
    return c10::complex<T>(
        __shfl_down_sync(mask, value.real_, delta, width),
        __shfl_down_sync(mask, value.imag_, delta, width));
#else
    return c10::complex<T>(
        __shfl_down(value.real_, delta, width),
        __shfl_down(value.imag_, delta, width));
#endif
}

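// Usage sketch (illustrative, not part of the original header): the classic
// warp-wide sum reduction built on WARP_SHFL_DOWN, as in the Kepler reduction
// post linked above. `warpReduceSum` is a hypothetical helper, not a PyTorch
// API; it assumes a full warp participates and leaves the result in lane 0.
template <typename T>
__device__ __forceinline__ T warpReduceSum(T val) {
  // Halve the shuffle distance each step; after log2(warpSize) steps,
  // lane 0 holds the sum over the whole warp.
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    val += WARP_SHFL_DOWN(val, offset);
  }
  return val;
}
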
/**
 * For compute capability 3.5+, perform the load through the read-only data
 * cache using __ldg; otherwise (or on ROCm) fall back to a plain load.
 */
template <typename T>
__device__ __forceinline__ T doLdg(const T* p) {
#if __CUDA_ARCH__ >= 350 && !defined(USE_ROCM)
  return __ldg(p);
#else
  return *p;
#endif
}
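
// Usage sketch (illustrative, not part of the original header): doLdg for the
// read-only input of a simple elementwise kernel. `scaleKernel` and its
// parameters are hypothetical, not PyTorch APIs.
template <typename T>
__global__ void scaleKernel(const T* in, T* out, T alpha, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // doLdg hints the read-only cache path on CC 3.5+; plain load elsewhere.
    out[i] = alpha * doLdg(in + i);
  }
}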