/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#pragma once

#include <cutlass/arch/mma.h>

////////////////////////////////////////////////////////////////////////////////
// Some helper functions
////////////////////////////////////////////////////////////////////////////////
#define DISPATCH_TYPES(tensor, func)                                         \
  {                                                                          \
    if (tensor.scalar_type() == at::ScalarType::Float) {                     \
      using scalar_t = float;                                                \
      func();                                                                \
    } else if (tensor.scalar_type() == at::ScalarType::Half) {               \
      using scalar_t = cutlass::half_t;                                      \
      func();                                                                \
    } else if (tensor.scalar_type() == at::ScalarType::BFloat16) {           \
      using scalar_t = cutlass::bfloat16_t;                                  \
      func();                                                                \
    } else {                                                                 \
      TORCH_CHECK(false, "Only fp32, half & bf16 supported at the moment");  \
    }                                                                        \
  }
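// Illustrative usage (a sketch, not part of the original header): `query` is
// assumed to be an at::Tensor and `my_kernel` a caller-provided template. The
// lambda is wrapped in parentheses so that commas inside it do not split the
// macro arguments:
//
//   DISPATCH_TYPES(query, ([&] {
//     // scalar_t is float, cutlass::half_t or cutlass::bfloat16_t here
//     my_kernel<scalar_t>(query);
//   }));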

#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \
  {                                         \
    if (BOOL_V) {                           \
      constexpr bool BOOL_NAME = true;      \
      F();                                  \
    } else {                                \
      constexpr bool BOOL_NAME = false;     \
      F();                                  \
    }                                       \
  }
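// Illustrative usage (a sketch with a hypothetical `is_aligned` flag and
// `run` template):
//
//   DISPATCH_BOOL(is_aligned, kIsAligned, ([&] {
//     run<kIsAligned>();
//   }));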
#define DISPATCH_ARCHTAG(CC, func)                                         \
  {                                                                        \
    if (CC >= 80) {                                                        \
      using ArchTag = cutlass::arch::Sm80;                                 \
      func();                                                              \
    } else if (CC >= 75) {                                                 \
      using ArchTag = cutlass::arch::Sm75;                                 \
      func();                                                              \
    } else if (CC >= 70) {                                                 \
      using ArchTag = cutlass::arch::Sm70;                                 \
      func();                                                              \
    } else if (CC >= 50) {                                                 \
      using ArchTag = cutlass::arch::Sm50;                                 \
      func();                                                              \
    } else {                                                               \
      TORCH_CHECK(                                                         \
          false,                                                           \
          "Your device is too old. We require compute capability >= 50"); \
    }                                                                      \
  }
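// Illustrative usage (a sketch; assumes `computeCapability` was computed from
// the device properties, e.g. `major * 10 + minor`):
//
//   DISPATCH_ARCHTAG(computeCapability, ([&] {
//     // ArchTag is one of cutlass::arch::Sm50/Sm70/Sm75/Sm80 here
//     launch_kernel<ArchTag>();
//   }));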

#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR)                             \
  TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor");         \
  TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor");     \
  TORCH_CHECK(TENSOR.is_contiguous());

#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR)                         \
  TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor");         \
  TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor");     \
  TORCH_CHECK(                                                             \
      TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous");

#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
  TORCH_CHECK(                            \
      uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")

#define ASSIGN_CHECK_OVERFLOW(A, B)                                        \
  {                                                                        \
    A = B;                                                                 \
    TORCH_CHECK(                                                           \
        B < std::numeric_limits<decltype(A)>::max(), #B " overflows");     \
  }
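// Illustrative usage (hypothetical names): copy a 64-bit stride into a
// narrower field while checking that it fits.
//
//   int32_t q_strideM;
//   ASSIGN_CHECK_OVERFLOW(q_strideM, query.stride(1));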

namespace gemm_kernel_utils {

template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
  return (n + m - 1) / m;
}

template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) {
  return ((n + m - 1) / m) * m;
}
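// For example, ceil_div(10, 4) == 3 (ten elements need three groups of four)
// and align_up(10, 4) == 12 (the next multiple of four at or above ten).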

////////////////////////////////////////////////////////////////////////////////
// Determine the type of GEMM we do (TensorCores or not, Shapes ...)
// TODO: Maybe we could rely on Cutlass's DefaultGemm templates
////////////////////////////////////////////////////////////////////////////////

// Fallback to Simt (FMA on cuda cores) if not in a special case below
template <typename ArchTag, typename scalar_t_, typename Enable = void>
struct DefaultGemmType {
  static constexpr int ThreadK = 8;
  static constexpr int WarpK = 8;
  static constexpr int kMinimumAlignment = 1;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
  using OpClass = cutlass::arch::OpClassSimt;
  using Operator = cutlass::arch::OpMultiplyAdd;
};

// Specialization for tensorcores with f32
template <typename ArchTag>
struct DefaultGemmType<
    ArchTag,
    float,
    typename cutlass::platform::enable_if<
        ArchTag::kMinComputeCapability >= 80>::type> {
  static constexpr int ThreadK = 32;
  static constexpr int WarpK = 32;
  static constexpr int kMinimumAlignment = 4;
  using OpClass = cutlass::arch::OpClassTensorOp;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
  using Operator = cutlass::arch::OpMultiplyAddFastF32;
};

// Specialization for tensorcores with f16/bf16 - Sm75+
template <typename ArchTag, typename scalar_t>
struct DefaultGemmType<
    ArchTag,
    scalar_t,
    typename cutlass::platform::enable_if<
        ArchTag::kMinComputeCapability >= 75 &&
        cutlass::sizeof_bits<scalar_t>::value == 16>::type> {
  static constexpr int ThreadK = 32;
  static constexpr int WarpK = 32;
  static constexpr int kMinimumAlignment = 4;
  using OpClass = cutlass::arch::OpClassTensorOp;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
  using Operator = cutlass::arch::OpMultiplyAdd;
};

// Specialization for tensorcores with f16 - Volta
template <>
struct DefaultGemmType<cutlass::arch::Sm70, cutlass::half_t, void> {
  static constexpr int ThreadK = 32;
  static constexpr int WarpK = 32;
  static constexpr int kMinimumAlignment = 2;
  using OpClass = cutlass::arch::OpClassTensorOp;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
  using Operator = cutlass::arch::OpMultiplyAdd;
};
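// For instance, DefaultGemmType<cutlass::arch::Sm80, cutlass::half_t> selects
// the Sm75+ 16-bit tensor-core specialization above (OpClassTensorOp with
// 16x8x8 instructions), while DefaultGemmType<cutlass::arch::Sm50, float>
// falls back to the Simt default.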

// Enables writing
// `auto x = kCondition ? fa(arg) : fb(arg)`
// when `fa` and `fb` have different types
template <bool kVal, typename TA, typename TB>
struct call_conditional;

template <typename TA, typename TB>
struct call_conditional<true, TA, TB> {
  template <typename Arg>
  static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg)
      -> decltype(ta(arg)) {
    return ta(arg);
  }
};

template <typename TA, typename TB>
struct call_conditional<false, TA, TB> {
  template <typename Arg>
  static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg)
      -> decltype(tb(arg)) {
    return tb(arg);
  }
};
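// Illustrative usage (a sketch with hypothetical callables `fa` and `fb`
// whose return types differ):
//
//   auto x = call_conditional<kCondition, decltype(fa), decltype(fb)>::apply(
//       fa, fb, arg);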

////////////////////////////////////////////////////////////////////////////////
// Mark a variable as warp-uniform - enables some compiler optimizations
// The cheapest way to do it is just to broadcast it from lane 0
////////////////////////////////////////////////////////////////////////////////

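// Note: the scalar overload below broadcasts the value as a single 32-bit
// word from lane 0, so it only works for types no wider than 32 bits; the
// pointer overload broadcasts the two 32-bit halves of a 64-bit pointer
// separately.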
template <typename T>
CUTLASS_DEVICE T warp_uniform(T value) {
  struct {
    union {
      T value;
      uint32_t asInt;
    };
  } p;
  p.value = value;
  p.asInt = __shfl_sync(0xffffffff, (unsigned)p.asInt, 0);
  return p.value;
}

template <typename T>
CUTLASS_DEVICE T* warp_uniform(T* ptr) {
  struct {
    union {
      T* ptr;
      uint32_t asInt[2];
    };
  } p;
  p.ptr = ptr;
  p.asInt[0] = warp_uniform(p.asInt[0]);
  p.asInt[1] = warp_uniform(p.asInt[1]);
  return p.ptr;
}
} // namespace gemm_kernel_utils