// aten/src/ATen/native/cuda/KernelUtils.cuh
#pragma once
#include <ATen/cuda/Atomic.cuh>

#if !(defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
#include <cuda_bf16.h>
#endif

namespace at {
namespace native {

__device__ __forceinline__ size_t
idx(const size_t nc,
    const size_t height,
    const size_t width,
    const size_t h,
    const size_t w) {
  return (nc * height + h) * width + w;
}
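
// For example (illustrative): with height == 4 and width == 5,
// idx(2, 4, 5, 1, 3) returns (2 * 4 + 1) * 5 + 3 == 48, the usual contiguous
// (NCHW) offset where `nc` is the flattened batch/channel index.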

// for channels-last
__device__ __forceinline__ size_t
idx_cl(
  const size_t n, const size_t h, const size_t w, const size_t c,
  const size_t height, const size_t width, const size_t channel
) {
  return ((n * height + h) * width + w) * channel + c;
}
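
// For example (illustrative): with height == 4, width == 5 and channel == 3,
// idx_cl(1, 2, 0, 1, 4, 5, 3) returns ((1 * 4 + 2) * 5 + 0) * 3 + 1 == 91,
// the channels-last (NHWC) offset.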

// fastSpecializedAtomicAdd (and fastAtomicAdd) are an optimization
// that speeds up half-precision atomics.  The situation with half-
// precision atomics is that we have a slow __half atomic and
// a fast vectored __half2 atomic (this can be worth up to a 6x
// speedup, see https://github.com/pytorch/pytorch/pull/21879).
// We can convert a __half atomic into a __half2 atomic by simply
// pairing the __half with a zero entry on the left/right depending
// on alignment... but only if this wouldn't cause an out-of-bounds
// access!  Thus, you must specify tensor and numel so we can check
// whether you would be out of bounds, and fall back to a plain
// __half atomic if you would be.
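//
// For example (illustrative): if tensor + index is 4-byte aligned, the target
// __half is the low half of an aligned __half2, so we atomically add
// {value, 0} to that __half2 (the neighbor at index + 1 is unchanged by the
// +0).  If it is the high half instead, we add {0, value} to the __half2
// starting at index - 1.  The numel check guarantees the paired neighbor is
// still inside the tensor; otherwise we fall back to the plain __half atomic.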
template <
    typename scalar_t,
    typename index_t,
    typename std::enable_if<std::is_same<c10::Half, scalar_t>::value>::type* =
        nullptr>
__device__ __forceinline__ void fastSpecializedAtomicAdd(
    scalar_t* tensor,
    index_t index,
    const index_t numel,
    scalar_t value) {
#if (                      \
    (defined(USE_ROCM)) || \
    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
  gpuAtomicAddNoReturn(
      reinterpret_cast<at::Half*>(tensor) + index,
      static_cast<at::Half>(value));
#else
  // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned)
  __half* target_addr = reinterpret_cast<__half*>(tensor + index);
  bool low_byte = (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__half2) == 0);

  if (low_byte && index < (numel - 1)) {
    __half2 value2;
    value2.x = static_cast<__half>(value);
    value2.y = __int2half_rz(0);
    atomicAdd(reinterpret_cast<__half2*>(target_addr), value2);

  } else if (!low_byte && index > 0) {
    __half2 value2;
    value2.x = __int2half_rz(0);
    value2.y = static_cast<__half>(value);
    atomicAdd(reinterpret_cast<__half2*>(target_addr - 1), value2);

  } else {
    atomicAdd(
        reinterpret_cast<__half*>(tensor) + index, static_cast<__half>(value));
  }
#endif
}

template <
    typename scalar_t,
    typename index_t,
    typename std::enable_if<std::is_same<c10::BFloat16, scalar_t>::value>::type* =
        nullptr>
__device__ __forceinline__ void fastSpecializedAtomicAdd(
    scalar_t* tensor,
    index_t index,
    const index_t numel,
    scalar_t value) {
#if (                      \
    (defined(USE_ROCM)) || \
    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
  gpuAtomicAddNoReturn(
      reinterpret_cast<at::BFloat16*>(tensor) + index,
      static_cast<at::BFloat16>(value));
#else
  // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned)
  __nv_bfloat16* target_addr = reinterpret_cast<__nv_bfloat16*>(tensor + index);
  bool low_byte = (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__nv_bfloat162) == 0);

  if (low_byte && index < (numel - 1)) {
    __nv_bfloat162 value2;
    value2.x = *reinterpret_cast<__nv_bfloat16*>(&value);
    value2.y = __int2bfloat16_rz(0);
    atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr), value2);

  } else if (!low_byte && index > 0) {
    __nv_bfloat162 value2;
    value2.x = __int2bfloat16_rz(0);
    value2.y = *reinterpret_cast<__nv_bfloat16*>(&value);
    atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2);

  } else {
    atomicAdd(
        reinterpret_cast<__nv_bfloat16*>(tensor) + index, *reinterpret_cast<__nv_bfloat16*>(&value));
  }
#endif
}


template <
    typename scalar_t,
    typename index_t,
    typename std::enable_if<!std::is_same<c10::Half, scalar_t>::value && !std::is_same<c10::BFloat16, scalar_t>::value >::type* =
        nullptr>
__device__ __forceinline__ void fastSpecializedAtomicAdd(
    scalar_t* tensor,
    index_t index,
    const index_t numel,
    scalar_t value) {
  gpuAtomicAddNoReturn(tensor + index, value);
}

template <class scalar_t, class index_t>
__device__ __forceinline__ void fastAtomicAdd(
    scalar_t* tensor,
    index_t index,
    const index_t numel,
    scalar_t value,
    bool fast_atomics) {
  if (fast_atomics) {
    fastSpecializedAtomicAdd(tensor, index, numel, value);
  } else {
    gpuAtomicAddNoReturn(tensor + index, value);
  }
}
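
// Minimal usage sketch (illustrative, not part of the original header): a
// scatter-add kernel that routes through fastAtomicAdd, so half/bfloat16
// destinations take the vectorized __half2/__nv_bfloat162 path when the
// alignment and bounds checks allow it.  The kernel and parameter names
// below are hypothetical.
//
//   template <typename scalar_t, typename index_t>
//   __global__ void scatter_add_kernel(
//       scalar_t* out,          // destination tensor data
//       const scalar_t* src,    // source values
//       const index_t* indices, // destination index for each source element
//       index_t n_src,          // number of source elements
//       index_t out_numel) {    // numel of the destination tensor
//     index_t i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n_src) {
//       // fast_atomics == true opts in to fastSpecializedAtomicAdd.
//       fastAtomicAdd(out, indices[i], out_numel, src[i], true);
//     }
//   }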

} // namespace native
} // namespace at