/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
// This file is auto-generated. See "generate_kernels.py"
#pragma once
#include <ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h>
using namespace PyTorchMemEffAttention;
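// NOTE (editorial annotation; generate_kernels.py does not emit comments):
// going by kernel_forward.h, the AttentionKernel template arguments below
// read as <scalar_t, ArchTag, isAligned, kQueriesPerBlock, kKeysPerBlock,
// kMaxK, kSupportsDropout, kSupportsBias>. The kernel names encode the same
// choices: "aligned"/"notaligned" mirrors isAligned, "64x128" is the
// queries-by-keys tile shape, and the "rf" vs "gmem" suffix appears to
// distinguish register-file accumulation from global-memory accumulation
// (the "gmem" variants use kMaxK = 65536 for unbounded head dimensions).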
// ======== bf16 / sm80 ========
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_bf16_aligned_64x64_rf_sm80(typename AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 64, 128, 128, true, true>::kNumThreads,
    AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 64, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(typename AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 64, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_bf16_aligned_32x128_gmem_sm80(typename AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>::Params p);

template <typename T> void dispatch_cutlassF_bf16_sm80(T cb, int cc) {
    cb(AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 64, 64, 64, true, true>(), fmha_cutlassF_bf16_aligned_64x64_rf_sm80);
    cb(AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 64, 128, 128, true, true>(), fmha_cutlassF_bf16_aligned_64x128_rf_sm80);
    cb(AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>(), fmha_cutlassF_bf16_aligned_32x128_gmem_sm80);
}
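// NOTE (editorial annotation): every dispatch_cutlassF_<dtype>_<arch> helper
// in this file follows the pattern above: it invokes `cb` once per kernel
// specialization compiled for that (dtype, arch) family, passing a
// default-constructed kernel tag (useful for reading traits such as
// kNumThreads) together with the matching __global__ entry point. Note that
// `cc` is unused in these helpers; compute-capability filtering happens in
// dispatch_cutlassF at the bottom of the file.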

// ======== f16 / sm50 ========
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_64x64_rf_sm50(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_32x128_rf_sm50(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_32x128_gmem_sm50(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 32, 128, 65536, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_notaligned_64x64_rf_sm50(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_notaligned_32x128_rf_sm50(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_notaligned_32x128_gmem_sm50(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 32, 128, 65536, true, true>::Params p);

template <typename T> void dispatch_cutlassF_f16_sm50(T cb, int cc) {
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 64, 64, 64, true, true>(), fmha_cutlassF_f16_aligned_64x64_rf_sm50);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 32, 128, 128, true, true>(), fmha_cutlassF_f16_aligned_32x128_rf_sm50);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, true, 32, 128, 65536, true, true>(), fmha_cutlassF_f16_aligned_32x128_gmem_sm50);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 64, 64, 64, true, true>(), fmha_cutlassF_f16_notaligned_64x64_rf_sm50);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 32, 128, 128, true, true>(), fmha_cutlassF_f16_notaligned_32x128_rf_sm50);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm50, false, 32, 128, 65536, true, true>(), fmha_cutlassF_f16_notaligned_32x128_gmem_sm50);
}

// ======== f16 / sm70 ========
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_64x64_rf_sm70(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_32x128_rf_sm70(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_32x128_gmem_sm70(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 32, 128, 65536, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_notaligned_64x64_rf_sm70(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_notaligned_32x128_rf_sm70(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_notaligned_32x128_gmem_sm70(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 32, 128, 65536, true, true>::Params p);

template <typename T> void dispatch_cutlassF_f16_sm70(T cb, int cc) {
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 64, 64, 64, true, true>(), fmha_cutlassF_f16_aligned_64x64_rf_sm70);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 32, 128, 128, true, true>(), fmha_cutlassF_f16_aligned_32x128_rf_sm70);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, true, 32, 128, 65536, true, true>(), fmha_cutlassF_f16_aligned_32x128_gmem_sm70);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 64, 64, 64, true, true>(), fmha_cutlassF_f16_notaligned_64x64_rf_sm70);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 32, 128, 128, true, true>(), fmha_cutlassF_f16_notaligned_32x128_rf_sm70);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm70, false, 32, 128, 65536, true, true>(), fmha_cutlassF_f16_notaligned_32x128_gmem_sm70);
}

// ======== f16 / sm75 ========
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_64x64_rf_sm75(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_32x128_rf_sm75(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_32x128_gmem_sm75(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 32, 128, 65536, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_notaligned_64x64_rf_sm75(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_notaligned_32x128_rf_sm75(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_notaligned_32x128_gmem_sm75(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 32, 128, 65536, true, true>::Params p);

template <typename T> void dispatch_cutlassF_f16_sm75(T cb, int cc) {
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 64, 64, 64, true, true>(), fmha_cutlassF_f16_aligned_64x64_rf_sm75);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 32, 128, 128, true, true>(), fmha_cutlassF_f16_aligned_32x128_rf_sm75);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, true, 32, 128, 65536, true, true>(), fmha_cutlassF_f16_aligned_32x128_gmem_sm75);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 64, 64, 64, true, true>(), fmha_cutlassF_f16_notaligned_64x64_rf_sm75);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 32, 128, 128, true, true>(), fmha_cutlassF_f16_notaligned_32x128_rf_sm75);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm75, false, 32, 128, 65536, true, true>(), fmha_cutlassF_f16_notaligned_32x128_gmem_sm75);
}

// ======== f16 / sm80 ========
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_64x64_rf_sm80(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 64, 128, 128, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 64, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_64x128_rf_sm80(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 64, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f16_aligned_32x128_gmem_sm80(typename AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>::Params p);

template <typename T> void dispatch_cutlassF_f16_sm80(T cb, int cc) {
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 64, 64, 64, true, true>(), fmha_cutlassF_f16_aligned_64x64_rf_sm80);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 64, 128, 128, true, true>(), fmha_cutlassF_f16_aligned_64x128_rf_sm80);
    cb(AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>(), fmha_cutlassF_f16_aligned_32x128_gmem_sm80);
}

// ======== f32 / sm50 ========
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm50, true, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm50, true, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_64x64_rf_sm50(typename AttentionKernel<float, cutlass::arch::Sm50, true, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm50, true, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm50, true, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_32x128_rf_sm50(typename AttentionKernel<float, cutlass::arch::Sm50, true, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm50, true, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm50, true, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_32x128_gmem_sm50(typename AttentionKernel<float, cutlass::arch::Sm50, true, 32, 128, 65536, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm50, false, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm50, false, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_notaligned_64x64_rf_sm50(typename AttentionKernel<float, cutlass::arch::Sm50, false, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm50, false, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm50, false, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_notaligned_32x128_rf_sm50(typename AttentionKernel<float, cutlass::arch::Sm50, false, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm50, false, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm50, false, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_notaligned_32x128_gmem_sm50(typename AttentionKernel<float, cutlass::arch::Sm50, false, 32, 128, 65536, true, true>::Params p);

template <typename T> void dispatch_cutlassF_f32_sm50(T cb, int cc) {
    cb(AttentionKernel<float, cutlass::arch::Sm50, true, 64, 64, 64, true, true>(), fmha_cutlassF_f32_aligned_64x64_rf_sm50);
    cb(AttentionKernel<float, cutlass::arch::Sm50, true, 32, 128, 128, true, true>(), fmha_cutlassF_f32_aligned_32x128_rf_sm50);
    cb(AttentionKernel<float, cutlass::arch::Sm50, true, 32, 128, 65536, true, true>(), fmha_cutlassF_f32_aligned_32x128_gmem_sm50);
    cb(AttentionKernel<float, cutlass::arch::Sm50, false, 64, 64, 64, true, true>(), fmha_cutlassF_f32_notaligned_64x64_rf_sm50);
    cb(AttentionKernel<float, cutlass::arch::Sm50, false, 32, 128, 128, true, true>(), fmha_cutlassF_f32_notaligned_32x128_rf_sm50);
    cb(AttentionKernel<float, cutlass::arch::Sm50, false, 32, 128, 65536, true, true>(), fmha_cutlassF_f32_notaligned_32x128_gmem_sm50);
}

// ======== f32 / sm70 ========
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm70, true, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm70, true, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_64x64_rf_sm70(typename AttentionKernel<float, cutlass::arch::Sm70, true, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm70, true, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm70, true, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_32x128_rf_sm70(typename AttentionKernel<float, cutlass::arch::Sm70, true, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm70, true, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm70, true, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_32x128_gmem_sm70(typename AttentionKernel<float, cutlass::arch::Sm70, true, 32, 128, 65536, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm70, false, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm70, false, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_notaligned_64x64_rf_sm70(typename AttentionKernel<float, cutlass::arch::Sm70, false, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm70, false, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm70, false, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_notaligned_32x128_rf_sm70(typename AttentionKernel<float, cutlass::arch::Sm70, false, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm70, false, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm70, false, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_notaligned_32x128_gmem_sm70(typename AttentionKernel<float, cutlass::arch::Sm70, false, 32, 128, 65536, true, true>::Params p);

template <typename T> void dispatch_cutlassF_f32_sm70(T cb, int cc) {
    cb(AttentionKernel<float, cutlass::arch::Sm70, true, 64, 64, 64, true, true>(), fmha_cutlassF_f32_aligned_64x64_rf_sm70);
    cb(AttentionKernel<float, cutlass::arch::Sm70, true, 32, 128, 128, true, true>(), fmha_cutlassF_f32_aligned_32x128_rf_sm70);
    cb(AttentionKernel<float, cutlass::arch::Sm70, true, 32, 128, 65536, true, true>(), fmha_cutlassF_f32_aligned_32x128_gmem_sm70);
    cb(AttentionKernel<float, cutlass::arch::Sm70, false, 64, 64, 64, true, true>(), fmha_cutlassF_f32_notaligned_64x64_rf_sm70);
    cb(AttentionKernel<float, cutlass::arch::Sm70, false, 32, 128, 128, true, true>(), fmha_cutlassF_f32_notaligned_32x128_rf_sm70);
    cb(AttentionKernel<float, cutlass::arch::Sm70, false, 32, 128, 65536, true, true>(), fmha_cutlassF_f32_notaligned_32x128_gmem_sm70);
}

// ======== f32 / sm75 ========
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm75, true, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm75, true, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_64x64_rf_sm75(typename AttentionKernel<float, cutlass::arch::Sm75, true, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm75, true, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm75, true, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_32x128_rf_sm75(typename AttentionKernel<float, cutlass::arch::Sm75, true, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm75, true, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm75, true, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_32x128_gmem_sm75(typename AttentionKernel<float, cutlass::arch::Sm75, true, 32, 128, 65536, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm75, false, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm75, false, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_notaligned_64x64_rf_sm75(typename AttentionKernel<float, cutlass::arch::Sm75, false, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm75, false, 32, 128, 128, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm75, false, 32, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_notaligned_32x128_rf_sm75(typename AttentionKernel<float, cutlass::arch::Sm75, false, 32, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm75, false, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm75, false, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_notaligned_32x128_gmem_sm75(typename AttentionKernel<float, cutlass::arch::Sm75, false, 32, 128, 65536, true, true>::Params p);

template <typename T> void dispatch_cutlassF_f32_sm75(T cb, int cc) {
    cb(AttentionKernel<float, cutlass::arch::Sm75, true, 64, 64, 64, true, true>(), fmha_cutlassF_f32_aligned_64x64_rf_sm75);
    cb(AttentionKernel<float, cutlass::arch::Sm75, true, 32, 128, 128, true, true>(), fmha_cutlassF_f32_aligned_32x128_rf_sm75);
    cb(AttentionKernel<float, cutlass::arch::Sm75, true, 32, 128, 65536, true, true>(), fmha_cutlassF_f32_aligned_32x128_gmem_sm75);
    cb(AttentionKernel<float, cutlass::arch::Sm75, false, 64, 64, 64, true, true>(), fmha_cutlassF_f32_notaligned_64x64_rf_sm75);
    cb(AttentionKernel<float, cutlass::arch::Sm75, false, 32, 128, 128, true, true>(), fmha_cutlassF_f32_notaligned_32x128_rf_sm75);
    cb(AttentionKernel<float, cutlass::arch::Sm75, false, 32, 128, 65536, true, true>(), fmha_cutlassF_f32_notaligned_32x128_gmem_sm75);
}

// ======== f32 / sm80 ========
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_64x64_rf_sm80(typename AttentionKernel<float, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm80, true, 64, 128, 128, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm80, true, 64, 128, 128, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_64x128_rf_sm80(typename AttentionKernel<float, cutlass::arch::Sm80, true, 64, 128, 128, true, true>::Params p);
__global__ void __launch_bounds__(
    AttentionKernel<float, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>::kNumThreads,
    AttentionKernel<float, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>::kMinBlocksPerSm)
fmha_cutlassF_f32_aligned_32x128_gmem_sm80(typename AttentionKernel<float, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>::Params p);

template <typename T> void dispatch_cutlassF_f32_sm80(T cb, int cc) {
    cb(AttentionKernel<float, cutlass::arch::Sm80, true, 64, 64, 64, true, true>(), fmha_cutlassF_f32_aligned_64x64_rf_sm80);
    cb(AttentionKernel<float, cutlass::arch::Sm80, true, 64, 128, 128, true, true>(), fmha_cutlassF_f32_aligned_64x128_rf_sm80);
    cb(AttentionKernel<float, cutlass::arch::Sm80, true, 32, 128, 65536, true, true>(), fmha_cutlassF_f32_aligned_32x128_gmem_sm80);
}
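// NOTE (editorial annotation): the top-level dispatcher below routes on the
// element type DT and the compute capability `cc` (e.g. 80 for A100), picking
// the kernel family built for the newest architecture not exceeding `cc`:
// [50, 70) -> sm50, [70, 75) -> sm70, [75, 80) -> sm75, [80, 100) -> sm80.
// bf16 kernels are only generated for sm80+, so for bf16 any other range
// leaves `cb` uninvoked.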
template <typename DT, typename T>
void dispatch_cutlassF(T cb, int cc = 0) {

    if (std::is_same<DT, cutlass::bfloat16_t>::value && 80 <= cc && cc < 100) {
        dispatch_cutlassF_bf16_sm80(cb, cc);
    }
    if (std::is_same<DT, cutlass::half_t>::value && 50 <= cc && cc < 70) {
        dispatch_cutlassF_f16_sm50(cb, cc);
    }
    if (std::is_same<DT, cutlass::half_t>::value && 70 <= cc && cc < 75) {
        dispatch_cutlassF_f16_sm70(cb, cc);
    }
    if (std::is_same<DT, cutlass::half_t>::value && 75 <= cc && cc < 80) {
        dispatch_cutlassF_f16_sm75(cb, cc);
    }
    if (std::is_same<DT, cutlass::half_t>::value && 80 <= cc && cc < 100) {
        dispatch_cutlassF_f16_sm80(cb, cc);
    }
    if (std::is_same<DT, float>::value && 50 <= cc && cc < 70) {
        dispatch_cutlassF_f32_sm50(cb, cc);
    }
    if (std::is_same<DT, float>::value && 70 <= cc && cc < 75) {
        dispatch_cutlassF_f32_sm70(cb, cc);
    }
    if (std::is_same<DT, float>::value && 75 <= cc && cc < 80) {
        dispatch_cutlassF_f32_sm75(cb, cc);
    }
    if (std::is_same<DT, float>::value && 80 <= cc && cc < 100) {
        dispatch_cutlassF_f32_sm80(cb, cc);
    }
}
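// Example (editorial sketch, not part of the generated file): one way a
// caller can drive dispatch_cutlassF. The callback signature matches the
// dispatchers above; Params, SharedStorage, kMaxK, getBlocksGrid(), and
// getThreadsGrid() follow kernel_forward.h, but the selection policy, the
// helper name, and `max_head_dim` are illustrative assumptions only.
template <typename DT>
void example_run_cutlassF(int compute_capability, int max_head_dim, cudaStream_t stream) {
  bool launched = false;
  dispatch_cutlassF<DT>(
      [&](auto kernel_tag, auto kernel_fn) {
        using Kernel = decltype(kernel_tag);
        // Launch at most one kernel, and skip specializations whose
        // compile-time bound on the head dimension is too small.
        if (launched || Kernel::kMaxK < max_head_dim) {
          return;
        }
        typename Kernel::Params p;
        // ... fill p with query/key/value pointers, strides, and sizes ...
        size_t smem_bytes = sizeof(typename Kernel::SharedStorage);
        kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
        launched = true;
      },
      compute_capability);
}
// e.g. example_run_cutlassF<cutlass::half_t>(/*compute_capability=*/80,
//                                            /*max_head_dim=*/64, stream);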