/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#pragma once

#include <cstdint>
#include <limits>

#include <cutlass/arch/mma.h>

////////////////////////////////////////////////////////////////////////////////
// Some helper functions
////////////////////////////////////////////////////////////////////////////////
#define DISPATCH_TYPES(tensor, func)                                        \
  {                                                                         \
    if ((tensor).scalar_type() == at::ScalarType::Float) {                  \
      using scalar_t = float;                                               \
      func();                                                               \
    } else if ((tensor).scalar_type() == at::ScalarType::Half) {            \
      using scalar_t = cutlass::half_t;                                     \
      func();                                                               \
    } else if ((tensor).scalar_type() == at::ScalarType::BFloat16) {        \
      using scalar_t = cutlass::bfloat16_t;                                 \
      func();                                                               \
    } else {                                                                \
      TORCH_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \
    }                                                                       \
  }
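
// Illustrative usage only (`launch_attention_kernel` is a hypothetical
// placeholder, not an API defined in this file). Because the functor text is
// expanded inside the selected branch, `scalar_t` is visible in its body:
//
//   DISPATCH_TYPES(query, ([&]() {
//     launch_attention_kernel<scalar_t>(query, key, value);
//   }));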

#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \
  {                                         \
    if (BOOL_V) {                           \
      constexpr bool BOOL_NAME = true;      \
      F();                                  \
    } else {                                \
      constexpr bool BOOL_NAME = false;     \
      F();                                  \
    }                                       \
  }
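
// Illustrative usage only (`kHasBias`, `launch_kernel` and `params` are
// hypothetical names): turns a runtime boolean into a compile-time constant
// visible inside the functor.
//
//   DISPATCH_BOOL(bias.has_value(), kHasBias, ([&]() {
//     launch_kernel<kHasBias>(params);
//   }));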
#define DISPATCH_ARCHTAG(CC, func)                                        \
  {                                                                       \
    if (CC >= 80) {                                                       \
      using ArchTag = cutlass::arch::Sm80;                                \
      func();                                                             \
    } else if (CC >= 75) {                                                \
      using ArchTag = cutlass::arch::Sm75;                                \
      func();                                                             \
    } else if (CC >= 70) {                                                \
      using ArchTag = cutlass::arch::Sm70;                                \
      func();                                                             \
    } else if (CC >= 50) {                                                \
      using ArchTag = cutlass::arch::Sm50;                                \
      func();                                                             \
    } else {                                                              \
      TORCH_CHECK(                                                        \
          false,                                                          \
          "Your device is too old. We require compute capability >= 50"); \
    }                                                                     \
  }
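
// Illustrative usage only (assumes the caller queries the device's compute
// capability; `at::cuda::getCurrentDeviceProperties()` is the usual way to
// obtain it in ATen):
//
//   cudaDeviceProp* p = at::cuda::getCurrentDeviceProperties();
//   const int computeCapability = p->major * 10 + p->minor;
//   DISPATCH_ARCHTAG(computeCapability, ([&]() {
//     // `ArchTag` (e.g. cutlass::arch::Sm80) is visible here.
//   }));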

#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR)                         \
  TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor");     \
  TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
  TORCH_CHECK(TENSOR.is_contiguous());

#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR)                     \
  TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor");     \
  TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
  TORCH_CHECK(                                                         \
      TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous");

#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
  TORCH_CHECK(                            \
      uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")

#define ASSIGN_CHECK_OVERFLOW(A, B)                                    \
  {                                                                    \
    A = B;                                                             \
    TORCH_CHECK(                                                       \
        B < std::numeric_limits<decltype(A)>::max(), #B " overflows"); \
  }
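
// Note: the assignment above runs before the range check, so `A` briefly
// holds a possibly-truncated value; the TORCH_CHECK then fails if `B` is at
// least as large as the maximum value representable by `decltype(A)`.
// Illustrative usage (the field name is hypothetical):
//
//   ASSIGN_CHECK_OVERFLOW(params.q_strideM, query.stride(1));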

namespace gemm_kernel_utils {

template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
  return (n + m - 1) / m;
}

template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) {
  return ((n + m - 1) / m) * m;
}
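
// Worked examples (both assume non-negative `n` and positive `m`):
//   ceil_div(10, 4) == 3   // 10 elements need 3 tiles of 4
//   align_up(10, 4) == 12  // 10 rounded up to the next multiple of 4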

////////////////////////////////////////////////////////////////////////////////
// Determine the kind of GEMM we run (TensorCores or not, instruction shapes, ...)
// TODO: Maybe we could rely on Cutlass's DefaultGemm templates
////////////////////////////////////////////////////////////////////////////////

// Fall back to SIMT (FMA on CUDA cores) if none of the specializations below apply
template <typename ArchTag, typename scalar_t_, typename Enable = void>
struct DefaultGemmType {
  static constexpr int ThreadK = 8;
  static constexpr int WarpK = 8;
  static constexpr int kMinimumAlignment = 1;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
  using OpClass = cutlass::arch::OpClassSimt;
  using Operator = cutlass::arch::OpMultiplyAdd;
};

// Specialization for tensorcores with f32
template <typename ArchTag>
struct DefaultGemmType<
    ArchTag,
    float,
    typename cutlass::platform::enable_if<
        ArchTag::kMinComputeCapability >= 80>::type> {
  static constexpr int ThreadK = 32;
  static constexpr int WarpK = 32;
  static constexpr int kMinimumAlignment = 4;
  using OpClass = cutlass::arch::OpClassTensorOp;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
  using Operator = cutlass::arch::OpMultiplyAddFastF32;
};
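
// Note on the f32 specialization above: `cutlass::arch::OpMultiplyAddFastF32`
// is CUTLASS's "fast f32" tensor-op path, which emulates f32 multiply-
// accumulate with TF32 tensor-core instructions (splitting each operand into
// TF32 pieces), giving better accuracy than plain TF32 while still using
// tensor cores.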

// Specialization for tensorcores with f16/bf16 - Sm75+
template <typename ArchTag, typename scalar_t>
struct DefaultGemmType<
    ArchTag,
    scalar_t,
    typename cutlass::platform::enable_if<
        ArchTag::kMinComputeCapability >= 75 &&
        cutlass::sizeof_bits<scalar_t>::value == 16>::type> {
  static constexpr int ThreadK = 32;
  static constexpr int WarpK = 32;
  static constexpr int kMinimumAlignment = 4;
  using OpClass = cutlass::arch::OpClassTensorOp;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
  using Operator = cutlass::arch::OpMultiplyAdd;
};

// Specialization for tensorcores with f16 - Volta
template <>
struct DefaultGemmType<cutlass::arch::Sm70, cutlass::half_t, void> {
  static constexpr int ThreadK = 32;
  static constexpr int WarpK = 32;
  static constexpr int kMinimumAlignment = 2;
  using OpClass = cutlass::arch::OpClassTensorOp;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
  using Operator = cutlass::arch::OpMultiplyAdd;
};
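
// Minimal sketch of how these traits are typically consumed (illustrative
// only; `GemmType` and `kAlignment` are local aliases, not part of this
// header):
//
//   using GemmType = DefaultGemmType<ArchTag, scalar_t>;
//   using OpClass = typename GemmType::OpClass;
//   using InstructionShape = typename GemmType::InstructionShape;
//   static constexpr int kAlignment = GemmType::kMinimumAlignment;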

// Allows writing
// `auto x = kCondition ? fa(arg) : fb(arg)`
// even when `fa` and `fb` have different types
template <bool kVal, typename TA, typename TB>
struct call_conditional;

template <typename TA, typename TB>
struct call_conditional<true, TA, TB> {
  template <typename Arg>
  static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg)
      -> decltype(ta(arg)) {
    return ta(arg);
  }
};

template <typename TA, typename TB>
struct call_conditional<false, TA, TB> {
  template <typename Arg>
  static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg)
      -> decltype(tb(arg)) {
    return tb(arg);
  }
};
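
// Illustrative usage (the lambdas `fa`/`fb` and `kIsAligned` are
// placeholders):
//
//   auto result = call_conditional<kIsAligned, decltype(fa), decltype(fb)>::
//       apply(fa, fb, offset);
//
// Only the selected branch's return type participates in deduction, which is
// what lets the two branches return different types.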

////////////////////////////////////////////////////////////////////////////////
// Mark a variable as warp-uniform - enables some compiler optimizations
// The cheapest way to do it is just to broadcast it from lane 0
////////////////////////////////////////////////////////////////////////////////

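// Note: the value overload below assumes sizeof(T) == 4 (the value is
// reinterpreted as a single uint32_t lane for the shuffle); the pointer
// overload assumes 64-bit pointers and broadcasts the two 32-bit halves
// separately.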
template <typename T>
CUTLASS_DEVICE T warp_uniform(T value) {
  struct {
    union {
      T value;
      uint32_t asInt;
    };
  } p;
  p.value = value;
  p.asInt = __shfl_sync(0xffffffff, (unsigned)p.asInt, 0);
  return p.value;
}

template <typename T>
CUTLASS_DEVICE T* warp_uniform(T* ptr) {
  struct {
    union {
      T* ptr;
      uint32_t asInt[2];
    };
  } p;
  p.ptr = ptr;
  p.asInt[0] = warp_uniform(p.asInt[0]);
  p.asInt[1] = warp_uniform(p.asInt[1]);
  return p.ptr;
}
} // namespace gemm_kernel_utils