/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
// This file is auto-generated. See "generate_kernels.py"
#pragma once
#include <ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h>
using namespace PyTorchMemEffAttention;
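// Note (added for readability; not emitted by generate_kernels.py): every
// declaration below instantiates AttentionKernel from kernel_forward.h.
// Judging from that header, the template arguments are roughly the scalar
// type, the target architecture tag, whether the inputs are aligned, the
// query/key tile sizes, an upper bound on the head dimension, and feature
// flags; the kernel-name suffixes mirror this (e.g. "64x64" for the tile
// shape, "rf" vs "gmem" for register-file vs global-memory accumulation,
// "aligned"/"notaligned" for the alignment assumption). Defer to
// kernel_forward.h for the authoritative parameter meanings.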
// ======== bf16 / sm80 ========
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_bf16_aligned_64x64_rf_sm80(typename AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 64, 128, 128, true, true>::kNumThreads,
    AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 64, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(typename AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 64, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_bf16_aligned_32x128_gmem_sm80(typename AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>::Params p);

template <typename T> void dispatch_cutlassF_bf16_sm80(T cb, int cc) {
    cb(AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 64, 64, 64, true, true>(), fmha_cutlassF_bf16_aligned_64x64_rf_sm80);
    cb(AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 64, 128, 128, true, true>(), fmha_cutlassF_bf16_aligned_64x128_rf_sm80);
    cb(AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>(), fmha_cutlassF_bf16_aligned_32x128_gmem_sm80);
}

// ======== f16 / sm50 ========
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_64x64_rf_sm50(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_32x128_rf_sm50(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_32x128_gmem_sm50(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 32, 128, 65536, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_notaligned_64x64_rf_sm50(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_notaligned_32x128_rf_sm50(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_notaligned_32x128_gmem_sm50(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 32, 128, 65536, true, true>::Params p);

template <typename T> void dispatch_cutlassF_f16_sm50(T cb, int cc) {
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 64, 64, 64, true, true>(), fmha_cutlassF_f16_aligned_64x64_rf_sm50);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 32, 128, 128, true, true>(), fmha_cutlassF_f16_aligned_32x128_rf_sm50);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 32, 128, 65536, true, true>(), fmha_cutlassF_f16_aligned_32x128_gmem_sm50);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 64, 64, 64, true, true>(), fmha_cutlassF_f16_notaligned_64x64_rf_sm50);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 32, 128, 128, true, true>(), fmha_cutlassF_f16_notaligned_32x128_rf_sm50);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 32, 128, 65536, true, true>(), fmha_cutlassF_f16_notaligned_32x128_gmem_sm50);
}

// ======== f16 / sm70 ========
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_64x64_rf_sm70(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_32x128_rf_sm70(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_32x128_gmem_sm70(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 32, 128, 65536, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_notaligned_64x64_rf_sm70(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_notaligned_32x128_rf_sm70(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_notaligned_32x128_gmem_sm70(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 32, 128, 65536, true, true>::Params p);

template <typename T> void dispatch_cutlassF_f16_sm70(T cb, int cc) {
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 64, 64, 64, true, true>(), fmha_cutlassF_f16_aligned_64x64_rf_sm70);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 32, 128, 128, true, true>(), fmha_cutlassF_f16_aligned_32x128_rf_sm70);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 32, 128, 65536, true, true>(), fmha_cutlassF_f16_aligned_32x128_gmem_sm70);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 64, 64, 64, true, true>(), fmha_cutlassF_f16_notaligned_64x64_rf_sm70);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 32, 128, 128, true, true>(), fmha_cutlassF_f16_notaligned_32x128_rf_sm70);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 32, 128, 65536, true, true>(), fmha_cutlassF_f16_notaligned_32x128_gmem_sm70);
}

// ======== f16 / sm75 ========
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_64x64_rf_sm75(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_32x128_rf_sm75(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_32x128_gmem_sm75(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 32, 128, 65536, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_notaligned_64x64_rf_sm75(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_notaligned_32x128_rf_sm75(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_notaligned_32x128_gmem_sm75(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 32, 128, 65536, true, true>::Params p);

template <typename T> void dispatch_cutlassF_f16_sm75(T cb, int cc) {
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 64, 64, 64, true, true>(), fmha_cutlassF_f16_aligned_64x64_rf_sm75);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 32, 128, 128, true, true>(), fmha_cutlassF_f16_aligned_32x128_rf_sm75);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 32, 128, 65536, true, true>(), fmha_cutlassF_f16_aligned_32x128_gmem_sm75);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 64, 64, 64, true, true>(), fmha_cutlassF_f16_notaligned_64x64_rf_sm75);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 32, 128, 128, true, true>(), fmha_cutlassF_f16_notaligned_32x128_rf_sm75);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 32, 128, 65536, true, true>(), fmha_cutlassF_f16_notaligned_32x128_gmem_sm75);
}

// ======== f16 / sm80 ========
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_64x64_rf_sm80(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 64, 128, 128, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 64, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_64x128_rf_sm80(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 64, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_32x128_gmem_sm80(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>::Params p);

template <typename T> void dispatch_cutlassF_f16_sm80(T cb, int cc) {
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 64, 64, 64, true, true>(), fmha_cutlassF_f16_aligned_64x64_rf_sm80);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 64, 128, 128, true, true>(), fmha_cutlassF_f16_aligned_64x128_rf_sm80);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>(), fmha_cutlassF_f16_aligned_32x128_gmem_sm80);
}

// ======== f32 / sm50 ========
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm50, true, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm50, true, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_64x64_rf_sm50(typename AttentionKernel<float, cutlass::arch::Sm50, true, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm50, true, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm50, true, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_32x128_rf_sm50(typename AttentionKernel<float, cutlass::arch::Sm50, true, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm50, true, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm50, true, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_32x128_gmem_sm50(typename AttentionKernel<float, cutlass::arch::Sm50, true, 32, 128, 65536, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm50, false, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm50, false, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_notaligned_64x64_rf_sm50(typename AttentionKernel<float, cutlass::arch::Sm50, false, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm50, false, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm50, false, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_notaligned_32x128_rf_sm50(typename AttentionKernel<float, cutlass::arch::Sm50, false, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm50, false, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm50, false, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_notaligned_32x128_gmem_sm50(typename AttentionKernel<float, cutlass::arch::Sm50, false, 32, 128, 65536, true, true>::Params p);

template <typename T> void dispatch_cutlassF_f32_sm50(T cb, int cc) {
    cb(AttentionKernel<float, cutlass::arch::Sm50, true, 64, 64, 64, true, true>(), fmha_cutlassF_f32_aligned_64x64_rf_sm50);
    cb(AttentionKernel<float, cutlass::arch::Sm50, true, 32, 128, 128, true, true>(), fmha_cutlassF_f32_aligned_32x128_rf_sm50);
    cb(AttentionKernel<float, cutlass::arch::Sm50, true, 32, 128, 65536, true, true>(), fmha_cutlassF_f32_aligned_32x128_gmem_sm50);
    cb(AttentionKernel<float, cutlass::arch::Sm50, false, 64, 64, 64, true, true>(), fmha_cutlassF_f32_notaligned_64x64_rf_sm50);
    cb(AttentionKernel<float, cutlass::arch::Sm50, false, 32, 128, 128, true, true>(), fmha_cutlassF_f32_notaligned_32x128_rf_sm50);
    cb(AttentionKernel<float, cutlass::arch::Sm50, false, 32, 128, 65536, true, true>(), fmha_cutlassF_f32_notaligned_32x128_gmem_sm50);
}

// ======== f32 / sm70 ========
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm70, true, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm70, true, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_64x64_rf_sm70(typename AttentionKernel<float, cutlass::arch::Sm70, true, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm70, true, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm70, true, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_32x128_rf_sm70(typename AttentionKernel<float, cutlass::arch::Sm70, true, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm70, true, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm70, true, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_32x128_gmem_sm70(typename AttentionKernel<float, cutlass::arch::Sm70, true, 32, 128, 65536, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm70, false, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm70, false, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_notaligned_64x64_rf_sm70(typename AttentionKernel<float, cutlass::arch::Sm70, false, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm70, false, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm70, false, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_notaligned_32x128_rf_sm70(typename AttentionKernel<float, cutlass::arch::Sm70, false, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm70, false, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm70, false, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_notaligned_32x128_gmem_sm70(typename AttentionKernel<float, cutlass::arch::Sm70, false, 32, 128, 65536, true, true>::Params p);

template <typename T> void dispatch_cutlassF_f32_sm70(T cb, int cc) {
    cb(AttentionKernel<float, cutlass::arch::Sm70, true, 64, 64, 64, true, true>(), fmha_cutlassF_f32_aligned_64x64_rf_sm70);
    cb(AttentionKernel<float, cutlass::arch::Sm70, true, 32, 128, 128, true, true>(), fmha_cutlassF_f32_aligned_32x128_rf_sm70);
    cb(AttentionKernel<float, cutlass::arch::Sm70, true, 32, 128, 65536, true, true>(), fmha_cutlassF_f32_aligned_32x128_gmem_sm70);
    cb(AttentionKernel<float, cutlass::arch::Sm70, false, 64, 64, 64, true, true>(), fmha_cutlassF_f32_notaligned_64x64_rf_sm70);
    cb(AttentionKernel<float, cutlass::arch::Sm70, false, 32, 128, 128, true, true>(), fmha_cutlassF_f32_notaligned_32x128_rf_sm70);
    cb(AttentionKernel<float, cutlass::arch::Sm70, false, 32, 128, 65536, true, true>(), fmha_cutlassF_f32_notaligned_32x128_gmem_sm70);
}

// ======== f32 / sm75 ========
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm75, true, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm75, true, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_64x64_rf_sm75(typename AttentionKernel<float, cutlass::arch::Sm75, true, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm75, true, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm75, true, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_32x128_rf_sm75(typename AttentionKernel<float, cutlass::arch::Sm75, true, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm75, true, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm75, true, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_32x128_gmem_sm75(typename AttentionKernel<float, cutlass::arch::Sm75, true, 32, 128, 65536, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm75, false, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm75, false, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_notaligned_64x64_rf_sm75(typename AttentionKernel<float, cutlass::arch::Sm75, false, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm75, false, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm75, false, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_notaligned_32x128_rf_sm75(typename AttentionKernel<float, cutlass::arch::Sm75, false, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm75, false, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm75, false, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_notaligned_32x128_gmem_sm75(typename AttentionKernel<float, cutlass::arch::Sm75, false, 32, 128, 65536, true, true>::Params p);

template <typename T> void dispatch_cutlassF_f32_sm75(T cb, int cc) {
    cb(AttentionKernel<float, cutlass::arch::Sm75, true, 64, 64, 64, true, true>(), fmha_cutlassF_f32_aligned_64x64_rf_sm75);
    cb(AttentionKernel<float, cutlass::arch::Sm75, true, 32, 128, 128, true, true>(), fmha_cutlassF_f32_aligned_32x128_rf_sm75);
    cb(AttentionKernel<float, cutlass::arch::Sm75, true, 32, 128, 65536, true, true>(), fmha_cutlassF_f32_aligned_32x128_gmem_sm75);
    cb(AttentionKernel<float, cutlass::arch::Sm75, false, 64, 64, 64, true, true>(), fmha_cutlassF_f32_notaligned_64x64_rf_sm75);
    cb(AttentionKernel<float, cutlass::arch::Sm75, false, 32, 128, 128, true, true>(), fmha_cutlassF_f32_notaligned_32x128_rf_sm75);
    cb(AttentionKernel<float, cutlass::arch::Sm75, false, 32, 128, 65536, true, true>(), fmha_cutlassF_f32_notaligned_32x128_gmem_sm75);
}

// ======== f32 / sm80 ========
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_64x64_rf_sm80(typename AttentionKernel<float, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm80, true, 64, 128, 128, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm80, true, 64, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_64x128_rf_sm80(typename AttentionKernel<float, cutlass::arch::Sm80, true, 64, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_32x128_gmem_sm80(typename AttentionKernel<float, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>::Params p);

template <typename T> void dispatch_cutlassF_f32_sm80(T cb, int cc) {
    cb(AttentionKernel<float, cutlass::arch::Sm80, true, 64, 64, 64, true, true>(), fmha_cutlassF_f32_aligned_64x64_rf_sm80);
    cb(AttentionKernel<float, cutlass::arch::Sm80, true, 64, 128, 128, true, true>(), fmha_cutlassF_f32_aligned_64x128_rf_sm80);
    cb(AttentionKernel<float, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>(), fmha_cutlassF_f32_aligned_32x128_gmem_sm80);
}


template <typename DT, typename T>
void dispatch_cutlassF(T cb, int cc = 0) {

    if (std::is_same<DT, cutlass::bfloat16_t>::value && 80 <= cc && cc < 100) {
        dispatch_cutlassF_bf16_sm80(cb, cc);
    }
    if (std::is_same<DT, cutlass::half_t>::value && 50 <= cc && cc < 70) {
        dispatch_cutlassF_f16_sm50(cb, cc);
    }
    if (std::is_same<DT, cutlass::half_t>::value && 70 <= cc && cc < 75) {
        dispatch_cutlassF_f16_sm70(cb, cc);
    }
    if (std::is_same<DT, cutlass::half_t>::value && 75 <= cc && cc < 80) {
        dispatch_cutlassF_f16_sm75(cb, cc);
    }
    if (std::is_same<DT, cutlass::half_t>::value && 80 <= cc && cc < 100) {
        dispatch_cutlassF_f16_sm80(cb, cc);
    }
    if (std::is_same<DT, float>::value && 50 <= cc && cc < 70) {
        dispatch_cutlassF_f32_sm50(cb, cc);
    }
    if (std::is_same<DT, float>::value && 70 <= cc && cc < 75) {
        dispatch_cutlassF_f32_sm70(cb, cc);
    }
    if (std::is_same<DT, float>::value && 75 <= cc && cc < 80) {
        dispatch_cutlassF_f32_sm75(cb, cc);
    }
    if (std::is_same<DT, float>::value && 80 <= cc && cc < 100) {
        dispatch_cutlassF_f32_sm80(cb, cc);
    }
}
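// Usage sketch (illustrative only; not part of the generated output).
// dispatch_cutlassF<scalar_t>(cb, cc) invokes `cb` once for every kernel
// specialization compiled for the data type and compute capability `cc`,
// passing a default-constructed AttentionKernel (useful for reading its
// static configuration) together with the matching __global__ entry point.
// A caller would typically select one specialization inside the callback
// and launch it. The Params fields, launch configuration, and `stream`
// below are hypothetical and depend on the definitions in kernel_forward.h:
//
//   dispatch_cutlassF<cutlass::half_t>(
//       [&](auto kernel, auto kernel_fn) {
//         using Kernel = decltype(kernel);
//         typename Kernel::Params p;
//         // ... fill `p` with the attention inputs, then launch:
//         // kernel_fn<<<grid, block, smem_bytes, stream>>>(p);
//       },
//       /*cc=*/80);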