1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <type_traits>
3
4 #include <ATen/core/Tensor.h>
5 #include <ATen/AccumulateType.h>
6 #include <ATen/Dispatch.h>
7 #include <ATen/native/DispatchStub.h>
8 #include <ATen/NestedTensorImpl.h>
9 #include <ATen/TensorAccessor.h>
10 #include <ATen/TensorOperators.h>
11 #include <c10/util/Logging.h>
12 #include <c10/util/bit_cast.h>
13
14 #include <ATen/cuda/CUDAContext.h>
15 #include <ATen/cuda/CUDAGraphsUtils.cuh>
16 #include <ATen/cuda/detail/KernelUtils.h>
17 #include <ATen/cuda/detail/IndexUtils.cuh>
18 #include <ATen/native/NonSymbolicBC.h>
19 #include <ATen/native/cuda/Loops.cuh>
20 #include <ATen/native/cuda/MemoryAccess.cuh>
21 #include <ATen/native/cuda/PersistentSoftmax.cuh>
22 #include <ATen/native/cuda/block_reduce.cuh>
23 #include <c10/util/Optional.h>
24
25 #ifndef AT_PER_OPERATOR_HEADERS
26 #include <ATen/Functions.h>
27 #include <ATen/NativeFunctions.h>
28 #else
29 #include <ATen/ops/_efficient_attention_forward.h>
30 #include <ATen/ops/_efficient_attention_forward_native.h>
31 #include <ATen/ops/_fill_mem_eff_dropout_mask_native.h>
32 #include <ATen/ops/_flash_attention_forward.h>
33 #include <ATen/ops/_flash_attention_forward_native.h>
34 #include <ATen/ops/_fused_sdp_choice_native.h>
35 #include <ATen/ops/_masked_softmax.h>
36 #include <ATen/ops/_native_multi_head_attention_native.h>
37 #include <ATen/ops/scaled_dot_product_attention_native.h>
38 #include <ATen/ops/_scaled_dot_product_efficient_attention.h>
39 #include <ATen/ops/_scaled_dot_product_efficient_attention_native.h>
40 #include <ATen/ops/_scaled_dot_product_flash_attention.h>
41 #include <ATen/ops/_scaled_dot_product_flash_attention_native.h>
42 #include <ATen/ops/_softmax.h>
43 #include <ATen/ops/_transform_bias_rescale_qkv.h>
44 #include <ATen/ops/_triton_multi_head_attention_native.h>
45 #include <ATen/ops/_triton_scaled_dot_attention.h>
46 #include <ATen/ops/empty.h>
47 #include <ATen/ops/empty_like.h>
48 #include <ATen/ops/linear.h>
49 #include <ATen/ops/narrow_native.h>
50 #include <ATen/ops/scalar_tensor.h>
51 #include <ATen/ops/scaled_dot_product_attention.h>
52 #include <ATen/ops/split_native.h>
53 #include <ATen/ops/zeros.h>
54 #endif
55
56 #ifdef __HIP_PLATFORM_AMD__
57 #include <ATen/native/cudnn/hip/MHA.h>
58 #else
59 #include <ATen/native/cudnn/MHA.h>
60 #endif
61
62 #include <c10/cuda/CUDAMathCompat.h>
63
64 #include <ATen/native/transformers/attention.h>
65 #include <ATen/native/nested/NestedTensorUtils.h>
66 #include <ATen/native/nested/NestedTensorTransformerFunctions.h>
67 #include <ATen/native/transformers/cuda/sdp_utils.h>
68 #include <ATen/native/transformers/sdp_utils_cpp.h>
69
70 #ifdef USE_FLASH_ATTENTION
71 // FlashAttention Specific Imports
72 #include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
73 #endif
74 #ifdef USE_MEM_EFF_ATTENTION
75 #ifndef USE_ROCM
76 // MemoryEfficient Attention Specific Imports for CUDA
77 #include <ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h>
78 #include <ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h>
79 #include <ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h>
80 #else
81 // MemoryEfficient Attention Specific Imports for ROCM
82 #include <ATen/native/transformers/hip/aotriton_adapter.h>
83 #include <aotriton/flash.h>
84 #include <aotriton/runtime.h>
85 #endif
86 #endif
87
88 namespace at {
89
90 namespace native {
91
92 namespace {
93
94
// Number of scalar_t elements moved per vectorized load/store by the
// transform_bias_rescale_qkv kernels when their aligned path is selected.
constexpr int TRANSFORM_BIAS_RESCALE_VEC = 4;
96
// Fuses bias addition, Q rescaling and the [B, T, 3 * D] -> [3, B, NH, T, DH]
// layout change for a dense (non-nested) packed QKV projection:
//   q = (q + q_bias) * inv_sqrt_dim_per_head, k = k + k_bias, v = v + v_bias
//
// Launch layout: one block per (b, t) row, so gridDim.x must equal B * T; the
// threads of a block stride over the D = NH * DH features of that row.
//
// assume_aligned == true selects the vectorized path, which moves
// TRANSFORM_BIAS_RESCALE_VEC elements per access through
// memory::aligned_vector. The caller must then guarantee that D and DH are
// multiples of TRANSFORM_BIAS_RESCALE_VEC and that every vector-sized access
// into qkv, qkv_bias and q_k_v is suitably aligned.
template <typename scalar_t, typename accscalar_t, bool assume_aligned>
__global__ void transform_bias_rescale_qkv_kernel(
    // [B, T, 3 * D]
    const PackedTensorAccessor64<scalar_t, 3, RestrictPtrTraits> qkv,
    // [3 * D]
    const PackedTensorAccessor64<scalar_t, 1, RestrictPtrTraits> qkv_bias,
    // [3, B, NH, T, DH]
    PackedTensorAccessor64<scalar_t, 5, RestrictPtrTraits> q_k_v,
    const scalar_t inv_sqrt_dim_per_head) {
  const auto num_heads = q_k_v.size(2);
  const auto seq_len = q_k_v.size(3);
  const auto head_dim = q_k_v.size(4);
  const auto embed_dim = num_heads * head_dim;

  // The block index enumerates (batch, token) pairs, token fastest.
  const auto batch = blockIdx.x / seq_len;
  const auto token = blockIdx.x % seq_len;

  // All arithmetic happens in accscalar_t and is rounded to scalar_t once.
  const auto rescale = static_cast<accscalar_t>(inv_sqrt_dim_per_head);
  const auto with_bias = [](scalar_t x, scalar_t bias) -> scalar_t {
    return static_cast<scalar_t>(
        (static_cast<accscalar_t>(x) + static_cast<accscalar_t>(bias)));
  };
  const auto with_bias_rescaled = [rescale](scalar_t x, scalar_t bias) -> scalar_t {
    return static_cast<scalar_t>(
        (static_cast<accscalar_t>(x) + static_cast<accscalar_t>(bias)) *
        rescale);
  };

  if constexpr (assume_aligned) {
    constexpr int VEC = TRANSFORM_BIAS_RESCALE_VEC;
    using VecT = memory::aligned_vector<scalar_t, VEC>;

    // One wide load: VEC consecutive elements starting at src into dst.
    const auto vec_load = [](scalar_t(&dst)[VEC], const scalar_t* src) {
      *reinterpret_cast<VecT*>(&dst) = *reinterpret_cast<const VecT*>(src);
    };
    // One wide store: the VEC elements of src to dst.
    const auto vec_store = [](scalar_t* dst, const scalar_t(&src)[VEC]) {
      *reinterpret_cast<VecT*>(dst) = *reinterpret_cast<const VecT*>(&src);
    };

    const auto num_vecs = embed_dim / VEC;
    for (int32_t vec_idx = threadIdx.x; vec_idx < num_vecs;
         vec_idx += blockDim.x) {
      const auto feat = vec_idx * VEC;
      const auto head = feat / head_dim;
      const auto in_head = feat % head_dim;

      scalar_t bias_q[VEC];
      scalar_t bias_k[VEC];
      scalar_t bias_v[VEC];
      scalar_t val_q[VEC];
      scalar_t val_k[VEC];
      scalar_t val_v[VEC];

      // Requires D % VEC == 0 so the Q, K and V sections all start on a
      // vector boundary.
      vec_load(bias_q, &qkv_bias[feat + 0 * embed_dim]);
      vec_load(bias_k, &qkv_bias[feat + 1 * embed_dim]);
      vec_load(bias_v, &qkv_bias[feat + 2 * embed_dim]);

      vec_load(val_q, &qkv[batch][token][feat + 0 * embed_dim]);
      vec_load(val_k, &qkv[batch][token][feat + 1 * embed_dim]);
      vec_load(val_v, &qkv[batch][token][feat + 2 * embed_dim]);

      // TODO: specialize for float2half2/half2float2?
#pragma unroll
      for (int lane = 0; lane < VEC; ++lane) {
        val_q[lane] = with_bias_rescaled(val_q[lane], bias_q[lane]);
        val_k[lane] = with_bias(val_k[lane], bias_k[lane]);
        val_v[lane] = with_bias(val_v[lane], bias_v[lane]);
      }

      // Requires DH % VEC == 0 so a vector never straddles two heads.
      vec_store(&q_k_v[0][batch][head][token][in_head], val_q);
      vec_store(&q_k_v[1][batch][head][token][in_head], val_k);
      vec_store(&q_k_v[2][batch][head][token][in_head], val_v);
    }
  } else {
    // Element-wise fallback: same math, one scalar per access.
    for (int32_t feat = threadIdx.x; feat < embed_dim; feat += blockDim.x) {
      const auto head = feat / head_dim;
      const auto in_head = feat % head_dim;

      const scalar_t bias_q = qkv_bias[feat + 0 * embed_dim];
      const scalar_t bias_k = qkv_bias[feat + 1 * embed_dim];
      const scalar_t bias_v = qkv_bias[feat + 2 * embed_dim];

      const scalar_t in_q = qkv[batch][token][feat + 0 * embed_dim];
      const scalar_t in_k = qkv[batch][token][feat + 1 * embed_dim];
      const scalar_t in_v = qkv[batch][token][feat + 2 * embed_dim];

      q_k_v[0][batch][head][token][in_head] = with_bias_rescaled(in_q, bias_q);
      q_k_v[1][batch][head][token][in_head] = with_bias(in_k, bias_k);
      q_k_v[2][batch][head][token][in_head] = with_bias(in_v, bias_v);
    }
  }
}
197
// NestedTensor variant of transform_bias_rescale_qkv_kernel. Reads a packed QKV
// projection from a flat nested-tensor buffer, applies
//   q = (q + q_bias) * inv_sqrt_dim_per_head, k = k + k_bias, v = v + v_bias
// and writes a dense, zero-padded [3, B, NH, T, DH] output.
//
// Launch layout: one block per (b, t) row of the padded output, so gridDim.x
// must equal B * T; threads of a block stride over the D = NH * DH features.
//
// offsets[b] is the element offset of sequence b inside the qkv buffer and
// input_sizes[b] is the number of elements that sequence occupies. Positions
// of row t that fall at or beyond input_sizes[b] are padding and are written
// as zeros (no bias is added to padding).
//
// assume_aligned == true selects the vectorized path
// (TRANSFORM_BIAS_RESCALE_VEC elements per access); the caller must then
// guarantee D % VEC == 0, DH % VEC == 0 and suitable alignment of every
// vector-sized access.
template <typename scalar_t, typename accscalar_t, bool assume_aligned = false>
__global__ void transform_bias_rescale_qkv_add_padding_kernel(
    // [B, T, 3 * D], but it's a NestedTensor buffer
    const PackedTensorAccessor64<scalar_t, 1, RestrictPtrTraits> qkv,
    // [3 * D]
    const PackedTensorAccessor64<scalar_t, 1, RestrictPtrTraits> qkv_bias,
    const int* offsets,
    const int* input_sizes,
    // [3, B, NH, T, DH]
    PackedTensorAccessor64<scalar_t, 5, RestrictPtrTraits> q_k_v,
    const scalar_t inv_sqrt_dim_per_head) {
  const auto NH = q_k_v.size(2);
  const auto T = q_k_v.size(3);
  const auto DH = q_k_v.size(4);

  // The block index enumerates (batch, token) pairs, token fastest.
  const auto t = blockIdx.x % T;
  const auto b = blockIdx.x / T;

  const auto D = NH * DH;
  const auto _3D = 3 * D;

  const auto offset_for_batch = offsets[b];
  const auto input_dim = 1;
  const auto* sizes_i = input_sizes + b * input_dim;

  // All arithmetic happens in accscalar_t and is rounded to scalar_t once.
  const auto q_scale = static_cast<accscalar_t>(inv_sqrt_dim_per_head);
  const auto add_bias = [](scalar_t x, scalar_t bias) -> scalar_t {
    return static_cast<scalar_t>(
        (static_cast<accscalar_t>(x) + static_cast<accscalar_t>(bias)));
  };
  const auto add_bias_and_scale = [q_scale](scalar_t x, scalar_t bias) -> scalar_t {
    return static_cast<scalar_t>(
        (static_cast<accscalar_t>(x) + static_cast<accscalar_t>(bias)) *
        q_scale);
  };

  if (assume_aligned) {
    constexpr int VEC = TRANSFORM_BIAS_RESCALE_VEC;
    using LoadT = memory::aligned_vector<scalar_t, VEC>;
    for (int32_t d_v = threadIdx.x; d_v < D / VEC; d_v += blockDim.x) {
      auto d = d_v * VEC;
      auto nh = d / DH;
      auto dh = d % DH;
      scalar_t qkv_bias_q[VEC];
      scalar_t qkv_bias_k[VEC];
      scalar_t qkv_bias_v[VEC];
      scalar_t qkv_q[VEC];
      scalar_t qkv_k[VEC];
      scalar_t qkv_v[VEC];

      // Offsets are relative to the start of sequence b, so they can be
      // compared directly against that sequence's element count.
      const auto first_item_offset = t * _3D + d;
      const auto last_item_offset = first_item_offset + VEC - 1;
      const bool first_item_in_bounds = first_item_offset < sizes_i[0];
      const bool entire_vec_in_bounds = last_item_offset < sizes_i[0];

      // Here we require D % VEC == 0 for these vectorized loads.
      *reinterpret_cast<LoadT*>(&qkv_bias_q) =
          *reinterpret_cast<const LoadT*>(&qkv_bias[d + 0 * D]);
      *reinterpret_cast<LoadT*>(&qkv_bias_k) =
          *reinterpret_cast<const LoadT*>(&qkv_bias[d + 1 * D]);
      *reinterpret_cast<LoadT*>(&qkv_bias_v) =
          *reinterpret_cast<const LoadT*>(&qkv_bias[d + 2 * D]);

      if (entire_vec_in_bounds) {
        const auto offset = offset_for_batch + first_item_offset;
        *reinterpret_cast<LoadT*>(&qkv_q) =
            *reinterpret_cast<const LoadT*>(&qkv[offset + 0 * D]);
        *reinterpret_cast<LoadT*>(&qkv_k) =
            *reinterpret_cast<const LoadT*>(&qkv[offset + 1 * D]);
        *reinterpret_cast<LoadT*>(&qkv_v) =
            *reinterpret_cast<const LoadT*>(&qkv[offset + 2 * D]);
        // TODO: specialize for float2half2/half2float2?
#pragma unroll
        for (auto ii = 0; ii < VEC; ++ii) {
          qkv_q[ii] = add_bias_and_scale(qkv_q[ii], qkv_bias_q[ii]);
          qkv_k[ii] = add_bias(qkv_k[ii], qkv_bias_k[ii]);
          qkv_v[ii] = add_bias(qkv_v[ii], qkv_bias_v[ii]);
        }
      } else if (first_item_in_bounds) {
        // The vector straddles the end of the sequence: handle it element by
        // element. The bounds test must use the sequence-relative offset
        // (without offset_for_batch), exactly like first_item_in_bounds;
        // comparing the buffer-global offset against the per-sequence size
        // would wrongly zero valid elements of every sequence but the first.
#pragma unroll
        for (auto ii = 0; ii < VEC; ++ii) {
          const auto item_offset = first_item_offset + ii;
          if (item_offset < sizes_i[0]) {
            const auto offset = offset_for_batch + item_offset;
            qkv_q[ii] = add_bias_and_scale(qkv[offset + 0 * D], qkv_bias_q[ii]);
            qkv_k[ii] = add_bias(qkv[offset + 1 * D], qkv_bias_k[ii]);
            qkv_v[ii] = add_bias(qkv[offset + 2 * D], qkv_bias_v[ii]);
          } else {
            qkv_q[ii] = 0;
            qkv_k[ii] = 0;
            qkv_v[ii] = 0;
          }
        }
      } else {
        // Entire vector is padding.
#pragma unroll
        for (auto ii = 0; ii < VEC; ++ii) {
          qkv_q[ii] = 0;
          qkv_k[ii] = 0;
          qkv_v[ii] = 0;
        }
      }

      // Here we require DH % VEC == 0 for these vectorized stores.
      *reinterpret_cast<LoadT*>(&q_k_v[0][b][nh][t][dh]) =
          *reinterpret_cast<const LoadT*>(&qkv_q);
      *reinterpret_cast<LoadT*>(&q_k_v[1][b][nh][t][dh]) =
          *reinterpret_cast<const LoadT*>(&qkv_k);
      *reinterpret_cast<LoadT*>(&q_k_v[2][b][nh][t][dh]) =
          *reinterpret_cast<const LoadT*>(&qkv_v);
    }
  } else {
    // Element-wise fallback: same math and padding rule, one scalar per access.
    for (int32_t d = threadIdx.x; d < D; d += blockDim.x) {
      auto nh = d / DH;
      auto dh = d % DH;
      scalar_t qkv_bias_q = qkv_bias[d + 0 * D];
      scalar_t qkv_bias_k = qkv_bias[d + 1 * D];
      scalar_t qkv_bias_v = qkv_bias[d + 2 * D];

      const auto item_offset = t * _3D + d;
      const bool in_bounds = item_offset < sizes_i[0];
      scalar_t qkv_q, qkv_k, qkv_v;
      if (in_bounds) {
        const auto qkv_offset = offset_for_batch + item_offset;
        qkv_q = add_bias_and_scale(qkv[qkv_offset + 0 * D], qkv_bias_q);
        qkv_k = add_bias(qkv[qkv_offset + 1 * D], qkv_bias_k);
        qkv_v = add_bias(qkv[qkv_offset + 2 * D], qkv_bias_v);
      } else {
        qkv_q = 0;
        qkv_k = 0;
        qkv_v = 0;
      }

      q_k_v[0][b][nh][t][dh] = qkv_q;
      q_k_v[1][b][nh][t][dh] = qkv_k;
      q_k_v[2][b][nh][t][dh] = qkv_v;
    }
  }
}
366
// Multiplies column 0 and column 1 of a 2-D `sizes` tensor element-wise.
// Each column is taken as a width-1 slice along dim 1, so the result keeps a
// trailing dimension of size 1; it is returned contiguous.
Tensor collapse_dims_1_and_2(const Tensor& sizes) {
  const auto first_col =
      at::native::narrow_symint(sizes, /*dim=*/1, /*start=*/0, /*length=*/1);
  const auto second_col =
      at::native::narrow_symint(sizes, /*dim=*/1, /*start=*/1, /*length=*/1);

  auto product = first_col * second_col;
  return product.contiguous();
}
373
374 } // namespace
// compute q = (q + q_bias) / sqrt(dim_per_head), k = k + k_bias, v = v + v_bias
//
// qkv is the packed projection, either dense [B, T, 3 * D] or a NestedTensor;
// qkv_bias is [3 * D]. Returns (q, k, v), each a [B, num_head, T,
// dim_per_head] view into one shared allocation. For nested input T is the
// longest sequence rounded up to a multiple of 8 and shorter sequences are
// zero-padded.
//
// The vectorized kernels are used only when dim_per_head is a multiple of
// TRANSFORM_BIAS_RESCALE_VEC and the input and bias are contiguous and aligned
// to the full vector width in bytes; otherwise the element-wise kernels run.
__host__ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_cuda(
    const Tensor& qkv,
    const Tensor& qkv_bias,
    const int64_t num_head) {
  auto B = qkv.is_nested()
      ? get_nested_tensor_impl(qkv)->get_nested_sizes().size(0)
      : qkv.size(0);
  // TODO: calculate this without the std::vector -- NestedTensor_to_mask wants
  // this too
  auto T = qkv.is_nested()
      ? NestedTensor_get_max_size(*get_nested_tensor_impl(qkv))[0]
      : qkv.size(1);
  if (qkv.is_nested()) {
    // Don't mess with non-nested case for now since it's not set up to fiddle
    // with mask size.

    // Round T up to next multiple of 8 so as to be able to utilize Tensor
    // cores. Otherwise, sometimes with padding, *no* row will have the maximum
    // sequence length and so we'll have a non-divisible-by-8 dimension even if
    // the model author chose a multiple of 8.
    T = T + (8 - (T % 8)) % 8;
  }
  auto _3D = qkv_bias.size(0);
  auto D = _3D / 3;
  TORCH_CHECK(D % num_head == 0);
  const auto dim_per_head = D / num_head;
  auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv_bias.options());
#define CALL_KERNEL(assume_aligned)                                        \
  transform_bias_rescale_qkv_kernel<scalar_t, accscalar_t, assume_aligned> \
      <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(          \
          qkv.packed_accessor64<scalar_t, 3, RestrictPtrTraits>(),         \
          qkv_bias.packed_accessor64<scalar_t, 1, RestrictPtrTraits>(),    \
          q_k_v.packed_accessor64<scalar_t, 5, RestrictPtrTraits>(),       \
          1.0 / std::sqrt(static_cast<scalar_t>(dim_per_head)))
#define CALL_ADD_PADDING_KERNEL(assume_aligned)                         \
  transform_bias_rescale_qkv_add_padding_kernel<                        \
      scalar_t,                                                         \
      accscalar_t,                                                      \
      assume_aligned>                                                   \
      <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(       \
          nt_qkv_buffer                                                 \
              .packed_accessor64<scalar_t, 1, RestrictPtrTraits>(),     \
          qkv_bias.packed_accessor64<scalar_t, 1, RestrictPtrTraits>(), \
          offsets_ptr,                                                  \
          sizes_ptr,                                                    \
          q_k_v.packed_accessor64<scalar_t, 5, RestrictPtrTraits>(),    \
          1.0 / std::sqrt(static_cast<scalar_t>(dim_per_head)))

  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      qkv.scalar_type(),
      "transform_bias_rescale_qkv",
      [&] {
        using accscalar_t = acc_type<scalar_t, true>;
        // The exact vector type the kernels reinterpret memory as; its
        // alignment (in bytes) is what every vectorized access must satisfy.
        using LoadT =
            memory::aligned_vector<scalar_t, TRANSFORM_BIAS_RESCALE_VEC>;
        const auto is_vec_aligned = [](const void* ptr) {
          return reinterpret_cast<uintptr_t>(ptr) % alignof(LoadT) == 0;
        };

        // One block per (b, t) row; threads stride over the features.
        auto threads = std::max(
            std::min<int32_t>(1024, D / TRANSFORM_BIAS_RESCALE_VEC), 1);
        auto blocks = B * T;

        // Preconditions shared by both vectorized kernels: vectors must not
        // straddle heads, and the bias must be readable with wide loads.
        const bool aligned =
            ((dim_per_head % TRANSFORM_BIAS_RESCALE_VEC) == 0) &&
            qkv_bias.is_contiguous() && is_vec_aligned(qkv_bias.data_ptr());
        if (aligned) {
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
              D % TRANSFORM_BIAS_RESCALE_VEC == 0,
              "D = num_heads * dim_per_head, so we should have dim_per_head % "
              "TRANSFORM_BIAS_RESCALE_VEC == 0 => "
              "D % TRANSFORM_BIAS_RESCALE_VEC == 0");
        }
        if (qkv.is_nested()) {
          auto* nt_qkv = get_nested_tensor_impl(qkv);
          const at::Tensor& nt_qkv_buffer = nt_qkv->get_buffer();
          auto sizes = collapse_dims_1_and_2(nt_qkv->get_nested_sizes());
          auto offsets =
              NestedTensor_batch_offsets_from_size_tensor(sizes, sizes.numel());
          // Pack the per-sequence sizes right after the offsets so a single
          // host-to-device copy carries both arrays.
          at::native::narrow_symint(offsets, 0, sizes.numel() + 1, sizes.numel())
              .copy_(sizes.reshape({-1}));
          auto metadata = offsets.to(at::Device(kCUDA), at::kInt, true, true);
          const auto offsets_ptr = metadata.data_ptr<int>();
          const auto sizes_ptr = offsets_ptr + sizes.numel() + 1;
          const auto input_dim = sizes.sizes()[1];
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input_dim == 1);
          // The kernel reads from the flat buffer, so that is the pointer
          // whose alignment matters.
          if (aligned && is_vec_aligned(nt_qkv_buffer.data_ptr())) {
            CALL_ADD_PADDING_KERNEL(true);
          } else {
            CALL_ADD_PADDING_KERNEL(false);
          }
        } else if (
            aligned && qkv.is_contiguous() &&
            is_vec_aligned(qkv.data_ptr())) {
          // Wide loads from qkv need unit stride in the last dimension and
          // vector-aligned rows, both of which contiguity plus D % VEC == 0
          // provide.
          CALL_KERNEL(true);
        } else {
          CALL_KERNEL(false);
        }
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
#undef CALL_ADD_PADDING_KERNEL
#undef CALL_KERNEL
  // q, k and v are the three leading slices of the packed output.
  auto q_k_v_s =
      at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0);
  return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
}
478
// Multi-head self-attention forward for CUDA.
//
// query/key/value are 3-D ([B, T, D], possibly nested), qkv_weight is
// [3 * D, D], qkv_bias is [3 * D]. Returns (output, attention weights); the
// second element is an undefined Tensor unless need_weights is true, and is
// averaged over heads when average_attn_weights is also true.
//
// When query, key and value are the same tensor, dim_per_head is a multiple
// of 8, no weights are requested, no mask is given and a fused SDP backend is
// selected, the computation is routed through
// at::scaled_dot_product_attention. Otherwise the explicit
// projection / bmm / softmax / bmm / projection sequence below is used.
std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const int64_t embed_dim,
    const int64_t num_head,
    const Tensor& qkv_weight,
    const Tensor& qkv_bias,
    const Tensor& proj_weight,
    const Tensor& proj_bias,
    const std::optional<Tensor>& mask,
    bool need_weights,
    bool average_attn_weights,
    const std::optional<int64_t> mask_type) {
  // query shape: [B, T, D]
  // qkv_weight shape: [3 * D, D]

  TORCH_CHECK(
      !mask || !query.is_nested(),
      "NestedTensor with mask is not supported yet");
  const auto D = embed_dim;
  TORCH_CHECK(
      query.dim() == 3,
      "expected 3-D `query`, got ",
      query.dim(),
      "-D tensor");
  TORCH_CHECK(
      query.is_nested() || query.sizes()[2] == embed_dim,
      "passed-in embed_dim ",
      embed_dim,
      " didn't match last dim of query ",
      query.sizes()[2]);
  TORCH_CHECK(
      key.dim() == 3,
      "expected 3-D `key`, got ",
      key.dim(),
      "-D tensor");
  TORCH_CHECK(
      value.dim() == 3,
      "expected 3-D `value`, got ",
      value.dim(),
      "-D tensor");
  TORCH_CHECK(
      query.is_nested() || key.is_nested() || value.is_nested() ||
          (query.sizes() == key.sizes() && key.sizes() == value.sizes()),
      "expected `query`/`key`/`value` shapes to match");
  TORCH_CHECK(
      qkv_weight.dim() == 2,
      "expected 2-D `qkv_weight`, got ",
      qkv_weight.dim(),
      "-D tensor");
  TORCH_CHECK(
      D * 3 == qkv_weight.sizes()[0],
      "expected `qkv_weight` first dim to be 3x embed_dim");
  TORCH_CHECK(
      D == qkv_weight.sizes()[1],
      "expected `qkv_weight` second dim to be embed_dim");
  TORCH_CHECK(
      qkv_bias.dim() == 1,
      "expected 1-D `qkv_bias`, got ",
      qkv_bias.dim(),
      "-D tensor");
  TORCH_CHECK(
      qkv_bias.sizes()[0] == 3 * D,
      "expected `qkv_bias` first dim to be 3x embed_dim");
  TORCH_CHECK(D % num_head == 0, "`embed_dim` must divide evenly by `num_heads`");

#ifndef NDEBUG
  // B and T are only needed by the debug shape assertions below.
  const auto B = query.is_nested()
      ? get_nested_tensor_impl(query)->get_nested_sizes().size(0)
      : query.sizes()[0];
  auto T = query.is_nested() ? 0 : query.sizes()[1];

#endif
  const auto dim_per_head = D / num_head;
  if ((query.is_same(key) && key.is_same(value)) && dim_per_head % 8 == 0 && !need_weights) {

    // We have not done linear projection yet but the input for SDP
    // Is expected to be 4 dimensional. We "cheaply" create view tensors
    // That will then be used for checking hot path conditions with select_sd_backend
    auto q = query.view({query.size(0), -1, num_head, dim_per_head}).transpose(1, 2);
    auto k = key.view({key.size(0), -1, num_head, dim_per_head}).transpose(1, 2);
    auto v = value.view({value.size(0), -1, num_head, dim_per_head}).transpose(1, 2);

    sdp::sdp_params kernel_params{q, k, v, mask, 0.0, false};
    auto backend = select_sdp_backend(kernel_params);
    // For nested tensors, strides from the packed projection when seq_len is 1
    // will trigger a contiguous call in the kernel, so we prevent this
    bool no_seq_len_1_nested = query.is_nested() ? check_for_seq_len_1_nested_tensor(kernel_params, false) : true;
    // The API for transformer_encoder is a mask of shape (Batch_Size, Seq_len_q)
    // For mem-eff attention this will cause the expand call to error
    // For now the fused path is turned off whenever a mask is present, so we
    // do not have to deal with the different mask type shapes
    if (!mask.has_value() && no_seq_len_1_nested &&
        (backend == sdp::SDPBackend::flash_attention || backend == sdp::SDPBackend::efficient_attention ||
         backend == sdp::SDPBackend::cudnn_attention)) {
      auto x = at::linear(query, qkv_weight, qkv_bias);
      auto chunks = x.chunk(3, -1);
      auto x_size_0 = x.size(0);

      // Reshape each of q, k, v to [B, num_head, T, dim_per_head].
      chunks[0] = (chunks[0].view({x_size_0, -1, num_head, dim_per_head}))
                      .transpose(1, 2);
      chunks[1] = (chunks[1].view({x_size_0, -1, num_head, dim_per_head}))
                      .transpose(1, 2);
      chunks[2] = (chunks[2].view({x_size_0, -1, num_head, dim_per_head}))
                      .transpose(1, 2);
      auto y = at::scaled_dot_product_attention(
          chunks[0], chunks[1], chunks[2], mask, 0.0, false, c10::nullopt);

      auto past_sdp = y.transpose(1, 2).reshape({x_size_0, -1, embed_dim});
      return std::make_tuple(
          at::linear(past_sdp, proj_weight, proj_bias), Tensor());
    }
    // Returned math or error lets not use it
  }

  // shape: [B, T, 3 x D]
  auto qkv = qkv_projection(query, key, value, embed_dim, qkv_weight);

  // Nothing to attend over: return early with empty results.
  if (!qkv.is_nested() && qkv.numel() == 0) {
    if (query.is_nested()) {
      return std::make_tuple(Tensor(), Tensor());
    }
    return std::make_tuple(at::empty_like(query), Tensor());
  }

#ifndef NDEBUG
  if (!query.is_nested() || !qkv.is_nested()) {
    if (query.is_nested()) {
      T = qkv.size(1);
    }
    debug_assert_shape(__LINE__, qkv, {B, T, 3 * D});
  }
#endif

#ifdef DEBUG_PRINT_EACH_STEP
  if (!qkv.is_nested()) {
    std::cerr << "qkv: " << qkv << std::endl;
  }
#endif
  // shape: 3 x [B, num_head, T, dim_per_head]
  auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
  qkv = Tensor(); // Not used any more, allow free
  auto& q = std::get<0>(q_k_v);
  const auto& k = std::get<1>(q_k_v);
  const auto& v = std::get<2>(q_k_v);
#ifndef NDEBUG
  debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head});
  debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head});
  debug_assert_shape(__LINE__, v, {B, num_head, T, dim_per_head});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "q: " << q << std::endl;
  std::cerr << "k: " << k << std::endl;
  std::cerr << "v: " << v << std::endl;
#endif

  // shape: [B, num_head, T, T]
  auto qkt = bmm_nt(q, k);
  // q & k are dead but cannot be freed because they were packed with v
#ifndef NDEBUG
  debug_assert_shape(__LINE__, qkt, {B, num_head, T, T});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "qkt: " << qkt << std::endl;
#endif

  // shape: [B, num_head, T, T]
  // TODO: long-term, have a kernel that works with
  // NestedTensor directly if there is no mask passed
  qkt = masked_softmax(qkt, mask, query, mask_type);
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "qkt after softmax: " << qkt << std::endl;
#endif

  // shape: [B, num_head, T, dim_per_head]
  // reuse storage for q; we're done with it
  auto attn_ctx = bmm_nn(q, qkt, v);
  // qkv is not dead; we just reused storage for q!
  if (!need_weights) {
    // Drop the attention weights now; an undefined Tensor is returned for them.
    qkt = Tensor();
  }
#ifndef NDEBUG
  debug_assert_shape(__LINE__, attn_ctx, {B, num_head, T, dim_per_head});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "attn_ctx: " << attn_ctx << std::endl;
#endif

  // shape: [B, T, D]
  // Fuse transform_0213 inside
  auto proj = transform0213_gemm_nt_bias(
      attn_ctx, proj_weight, proj_bias, query);
#ifndef NDEBUG
  debug_assert_shape(__LINE__, proj, {B, T, D});
#endif
  if (need_weights && average_attn_weights) {
    // weights are not needed for full transformer, so don't worry too
    // much about performance -- we implement this just to make use
    // cases that don't disable need_weights still get some speedup.
    qkt = qkt.sum(1);
    qkt /= num_head;
  }
  return std::make_tuple(std::move(proj), std::move(qkt));
}
// Flash-attention backend of scaled_dot_product_attention.
//
// Expected input layout:
//   query: (Batch x Num_heads x Q_seq_len  x Dim_per_head)
//   key:   (Batch x Num_heads x KV_seq_len x Dim_per_head)
//   value: (Batch x Num_heads x KV_seq_len x Dim_per_head)
//
// Returns (attention, logsumexp, undefined, undefined, max_seqlen_q,
// max_seqlen_k, philox_seed, philox_offset, debug_attn_mask), where attention
// is transposed back to the input layout.
std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    std::optional<double> scale) {
  // Used for tracking usage statistics
  C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention");

  const int64_t max_seqlen_batch_q = query.size(2);
  const int64_t max_seqlen_batch_k = key.size(2);
  TORCH_CHECK(
      max_seqlen_batch_k == value.size(2),
      "Key and Value must have the same sequence length");

  // The forward op takes (Batch x Seq_len x Num_heads x Dim_per_head), so swap
  // the head and sequence dimensions of every input.
  const Tensor q_bshd = query.transpose(1, 2);
  const Tensor k_bshd = key.transpose(1, 2);
  const Tensor v_bshd = value.transpose(1, 2);

  auto fwd = at::_flash_attention_forward(
      q_bshd,
      k_bshd,
      v_bshd,
      /*cumulative_sequence_length_q=*/c10::nullopt,
      /*cumulative_sequence_length_k=*/c10::nullopt,
      max_seqlen_batch_q,
      max_seqlen_batch_k,
      dropout_p,
      is_causal,
      return_debug_mask,
      scale,
      /*window_size_left=*/c10::nullopt,
      /*window_size_right=*/c10::nullopt);

  const Tensor& output = std::get<0>(fwd);
  const Tensor& logsumexp = std::get<1>(fwd);
  const Tensor& philox_seed = std::get<2>(fwd);
  const Tensor& philox_offset = std::get<3>(fwd);
  const Tensor& debug_attn_mask = std::get<4>(fwd);

  // Swap heads and sequence back so the result matches the input layout.
  Tensor attention = output.transpose(1, 2);

  return std::make_tuple(
      attention,
      logsumexp,
      Tensor(),
      Tensor(),
      max_seqlen_batch_q,
      max_seqlen_batch_k,
      philox_seed,
      philox_offset,
      debug_attn_mask);
}
737
// cuDNN backend of scaled_dot_product_attention.
//
// Expected input layout:
//   query: (Batch x Num_heads x Q_seq_len  x Dim_per_head)
//   key:   (Batch x Num_heads x KV_seq_len x Dim_per_head)
//   value: (Batch x Num_heads x KV_seq_len x Dim_per_head)
//
// Returns (attention, log_sumexp, undefined, undefined, max_seqlen_q,
// max_seqlen_k, cudnn_seed, cudnn_offset, undefined).
std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_cuda(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    double dropout_p,
    bool is_causal,
    bool training,
    std::optional<double> scale) {
  // Used for tracking usage statistics
  C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention_cudnn");

  // Problem sizes, read off the (B, H, S, D) inputs.
  const int64_t b = query.size(0);
  const int64_t h = query.size(1);
  const int64_t s_q = query.size(2);
  const int64_t d_qk = query.size(3);
  const int64_t d_v = value.size(3);
  const int64_t s_kv = key.size(2);
  TORCH_CHECK(
      s_kv == value.size(2),
      "Key and Value must have the same sequence length");

  // Passed to run_cudnn_SDP_fprop as its softmax-stats and output arguments.
  Tensor attention;
  Tensor log_sumexp;

  // Single-element int64 tensors, created zeroed on query's device, that
  // carry the dropout seed/offset.
  const auto long_options = query.options().dtype(kLong);
  auto cudnn_seed = at::zeros({1}, long_options);
  auto cudnn_offset = at::zeros({1}, long_options);

  const auto softmax_scale =
      sdp::calculate_scale(query, scale).as_float_unchecked();

  run_cudnn_SDP_fprop(
      /*b=*/b,
      /*h=*/h,
      /*s_q=*/s_q,
      /*s_kv=*/s_kv,
      /*d_qk=*/d_qk,
      /*d_v=*/d_v,
      /*scaling_factor=*/softmax_scale,
      /*training=*/training,
      /*is_causal=*/is_causal,
      /*dropout_probability=*/dropout_p,
      /*q=*/query,
      /*k=*/key,
      /*v=*/value,
      /*softmaxstats=*/log_sumexp,
      /*o=*/attention,
      /*dropoutseed=*/cudnn_seed,
      /*dropoutoffset=*/cudnn_offset);

  return std::make_tuple(
      attention,
      log_sumexp,
      Tensor(),
      Tensor(),
      s_q,
      s_kv,
      cudnn_seed,
      cudnn_offset,
      Tensor());
}
788
// Memory-efficient attention backend of scaled_dot_product_attention.
//
// Inputs arrive as (Batch x Num_heads x Seq_len x Dim_per_head) and are
// transposed to (Batch x Seq_len x Num_heads x Dim_per_head) for
// at::_efficient_attention_forward. Returns (attention, log_sumexp, seed,
// offset) with attention transposed back to the input layout.
std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attention_cuda(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const std::optional<at::Tensor>& attn_bias,
    bool compute_log_sumexp,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale) {
  // Used for tracking usage statistics
  C10_LOG_API_USAGE_ONCE("torch.sdpa.mem_efficient_attention");

  const Tensor q_bshd = query.transpose(1, 2);
  const Tensor k_bshd = key.transpose(1, 2);
  const Tensor v_bshd = value.transpose(1, 2);

  // Causal attention is expressed to the kernel as a top-left-anchored
  // custom mask; otherwise no custom mask is applied.
  sdp::CustomMaskType mask_kind = sdp::CustomMaskType::NoCustomMask;
  if (is_causal) {
    mask_kind = sdp::CustomMaskType::CausalFromTopLeft;
  }

  // Result: (attention, log_sumexp, seed, offset, max_seqlen_q, max_seqlen_kv).
  auto fwd = at::_efficient_attention_forward(
      q_bshd,
      k_bshd,
      v_bshd,
      attn_bias,
      c10::nullopt,
      c10::nullopt,
      c10::nullopt,
      c10::nullopt,
      dropout_p,
      static_cast<int64_t>(mask_kind),
      compute_log_sumexp,
      scale);

  // Only the first four results are part of this op's return value.
  Tensor attention = std::get<0>(fwd).transpose(1, 2);
  return std::make_tuple(
      std::move(attention),
      std::move(std::get<1>(fwd)),
      std::move(std::get<2>(fwd)),
      std::move(std::get<3>(fwd)));
}
828
/**
 * Picks the SDPA backend for the given inputs on CUDA.
 *
 * Bundles the arguments into sdp::sdp_params, asks select_sdp_backend for a
 * choice and returns it as an int64_t. Raises if the selector reports
 * sdp::SDPBackend::error. `scale` is accepted but not consulted here.
 */
int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value,
        const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale){
  sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal};
  const auto chosen = select_sdp_backend(kernel_params);
  TORCH_CHECK(
      chosen != sdp::SDPBackend::error,
      "No viable backend for scaled_dot_product_attention was found. ",
      "This is likely due to turning off both the math kernel and the fused kernels.");
  return static_cast<int64_t>(chosen);
}
841
/**
 * Forward pass of FlashAttention.
 *
 * Dispatches to pytorch_flash::mha_varlen_fwd when cumulative sequence
 * lengths are supplied and to pytorch_flash::mha_fwd otherwise. Both
 * cumulative_sequence_length_q and cumulative_sequence_length_k must be set,
 * or neither.
 *
 * Returns (output, logsumexp, philox_seed, philox_offset, debug_attn_mask).
 * debug_attn_mask is replaced by an empty tensor of size {0} unless
 * return_debug_mask is true.
 *
 * Raises when the build does not define USE_FLASH_ATTENTION.
 */
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
_flash_attention_forward(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const std::optional<Tensor>& cumulative_sequence_length_q,
    const std::optional<Tensor>& cumulative_sequence_length_k,
    int64_t max_seqlen_batch_q,
    int64_t max_seqlen_batch_k,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    std::optional<double> scale,
    std::optional<int64_t> window_size_left,
    std::optional<int64_t> window_size_right,
    const std::optional<Tensor>& _seqused_k,
    const std::optional<Tensor>& _alibi_slopes
    ) {
#if defined(USE_FLASH_ATTENTION)
  const auto softmax_scale =
      sdp::calculate_scale(query, scale).as_float_unchecked();
  std::optional<Tensor> out = c10::nullopt;

  // Named local copies are kept on purpose: presumably the flash API takes
  // these optionals by non-const reference - confirm against flash_api.h
  // before removing them.
  std::optional<Tensor> seqused_k = _seqused_k;
  std::optional<Tensor> alibi_slopes = _alibi_slopes;

  // A missing window size is passed on as -1. The flash entry points receive
  // plain `int`, so the int64 -> int narrowing is spelled out explicitly.
  const int non_null_window_left =
      static_cast<int>(window_size_left.value_or(-1));
  const int non_null_window_right =
      static_cast<int>(window_size_right.value_or(-1));

  // We are going to have two paths:
  // 1. The standard MHA path for dense tensors
  // 2. The Varseqlen path
  TORCH_CHECK(
      cumulative_sequence_length_q.has_value() ==
          cumulative_sequence_length_k.has_value(),
      "cumulative_sequence_length_q and cumulative_sequence_length_k must be both set or both not set");
  // q_padded / k_padded / v_padded only exist to receive the tuple slots;
  // they are not returned.
  Tensor output, q_padded, k_padded, v_padded, logsumexp,
      philox_seed, philox_offset, debug_attn_mask;
  if (cumulative_sequence_length_q.has_value()) {
    std::tie(
        output,
        q_padded,
        k_padded,
        v_padded,
        logsumexp,
        philox_seed,
        philox_offset,
        debug_attn_mask) =
        pytorch_flash::mha_varlen_fwd(
            query,
            key,
            value,
            out,
            cumulative_sequence_length_q.value(),
            cumulative_sequence_length_k.value(),
            seqused_k, /*seqused_k*/
            alibi_slopes, /*alibi_slopes*/
            max_seqlen_batch_q,
            max_seqlen_batch_k,
            dropout_p,
            softmax_scale,
            false /*zero_tensors*/,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            return_debug_mask,
            c10::nullopt /*gen_*/);
  } else {
    std::tie(
        output,
        q_padded,
        k_padded,
        v_padded,
        logsumexp,
        philox_seed,
        philox_offset,
        debug_attn_mask) =
        pytorch_flash::mha_fwd(
            query,
            key,
            value,
            out,
            alibi_slopes,
            dropout_p,
            softmax_scale,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            return_debug_mask, /*return_softmax (this is used for testing)*/
            c10::nullopt);
  }
  debug_attn_mask =
      return_debug_mask ? debug_attn_mask : at::empty({0}, query.options());
  return std::make_tuple(
      std::move(output),
      std::move(logsumexp),
      std::move(philox_seed),
      std::move(philox_offset),
      std::move(debug_attn_mask));

#endif
  // Only reached when the build lacks flash attention support.
  TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.")
  return std::make_tuple(
      Tensor(),
      Tensor(),
      Tensor(),
      Tensor(),
      Tensor());
}
951
/**
 * Forward pass of memory-efficient attention.
 *
 * Validates the inputs, sets up the dropout RNG state (seed/offset tensors),
 * then either
 *  - (USE_ROCM) calls AOTriton's attn_fwd, or
 *  - (CUDA) lets dispatch_cutlassF offer candidate kernels to `launchKernel`,
 *    which launches the first one that passes its compatibility checks.
 *
 * Returns (output, logsumexp, seed, offset, max_seqlen_q, max_seqlen_k).
 * The last element is max_seqlen_k_ when the caller supplied it, otherwise
 * the locally derived value.
 *
 * Raises on invalid shapes/dtypes, when no kernel can be launched, when the
 * ROCm kernel reports an error, or when the build does not define
 * USE_MEM_EFF_ATTENTION.
 */
std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_attention_forward(
    const at::Tensor& query, // [b, seqlen, num_heads, K]
    const at::Tensor& key, // [b, seqlen, num_heads, K]
    const at::Tensor& value, // [b, seqlen, num_heads, Kv]
    const std::optional<at::Tensor>& bias, // [b, num_heads, seqlen, seqlen]
    // (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the
    // position of the first query token for batch $b
    const std::optional<at::Tensor>& seqstart_q,
    // (Mode 1MHK only) [b+1]: cu_seqlen_k[b] contains the
    // position of the first key token for batch $b
    const std::optional<at::Tensor>& seqstart_k,
    // (Mode 1MHK only) Maximum sequence length across batches
    const std::optional<int64_t> max_seqlen_q_,
    const std::optional<int64_t> max_seqlen_k_,
    double dropout_p, // attention matrix dropout probability
    int64_t custom_mask_type,
    bool compute_logsumexp,
    std::optional<double> scale,
    const std::optional<at::Tensor>& seqlen_k,
    const std::optional<int64_t> window_size) {
#if defined(USE_MEM_EFF_ATTENTION)
// TODO In theory it is possible to compile with _CUDA_ARCH < 5.0 and run on a
// machine that is >= 5.0. In practice, this is not a problem but since
// this would avoid runtime architecture checks, we should look into it

  TORCH_CHECK(query.dim() == 4);
  TORCH_CHECK(key.dim() == 4);
  TORCH_CHECK(value.dim() == 4);

  // Batch sizes
  TORCH_CHECK(query.size(0) == key.size(0));
  TORCH_CHECK(query.size(0) == value.size(0));

  // Sequence length
  TORCH_CHECK(key.size(1) == value.size(1));

  // Num heads
  TORCH_CHECK(query.size(2) == key.size(2));
  TORCH_CHECK(query.size(2) == value.size(2));

  // Embedding per head
  TORCH_CHECK(query.size(3) == key.size(3));

  int64_t max_seqlen_q = 0, max_seqlen_k = 0;
  TORCH_CHECK(seqstart_q.has_value() == seqstart_k.has_value());
  if (seqstart_q.has_value()) {
    // Variable-length mode: sequence boundaries come from seqstart_q/k.
    TORCH_CHECK(seqstart_q->scalar_type() == at::ScalarType::Int);
    TORCH_CHECK(seqstart_k->scalar_type() == at::ScalarType::Int);
    TORCH_CHECK(seqstart_q->dim() == 1 && seqstart_k->dim() == 1);
    CHECK_NOSPARSE_CONTIGUOUS_CUDA((*seqstart_q));
    CHECK_NOSPARSE_CONTIGUOUS_CUDA((*seqstart_k));
    TORCH_CHECK(seqstart_q->size(0) == seqstart_k->size(0));
    TORCH_CHECK(query.size(0) == 1, "cu_seqlen only supports batch_size=1");
    TORCH_CHECK(max_seqlen_q_.has_value());
    max_seqlen_q = *max_seqlen_q_;
    max_seqlen_k = 0; // TODO: is this actually being set inside the kernel anywhere?
                      // see https://github.com/pytorch/pytorch/issues/115590
  } else {
    max_seqlen_q = query.size(1);
    max_seqlen_k = key.size(1);
  }

  CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(query);
  CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(key);
  CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(value);

  at::cuda::CUDAGuard device_guard(query.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int64_t B = query.size(0);
  int64_t M = query.size(1);
  int64_t N = key.size(1);
  int64_t num_heads = query.size(-2);
  int64_t K = query.size(-1);
  int64_t Kv = value.size(-1);

  at::Tensor res;
  at::Tensor logsumexp;
  at::Tensor seed_t, offset_t;

  // Dropout is enabled for any dropout_p that is not exactly zero.
  const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO;

  // Note [Seed and Offset Device]
  // If we are currently in graph capture mode, we need to create the seed and offset tensors on the device.
  // This is necessary for CUDA graph-safe random number generation, which requires the seed and offset tensors
  // to be single element tensors on device. During graph capture, when the seed and offset tensors are passed
  // the pointers act as scratch space for storing the RNG state for the backwards pass.
  // When calling backwards, we either construct a PhiloxState with the pointers or the actual values.
  // For more information on CUDA graph-safe RNG states, see Note [CUDA Graph-safe RNG states].

  at::PhiloxCudaState philox_state;
  const bool in_capture_stream =
      at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None;
  auto device = in_capture_stream ? at::kCUDA : at::kCPU;
  if (use_dropout) {
    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());

    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    // if using dropout, we produce 1 random number for each element of the
    // attention tensor
    philox_state = gen->philox_cuda_state(B * num_heads * M * N);

    if (in_capture_stream) {
      // The seed and offset will be populated by the kernel
      seed_t = at::empty({}, at::dtype(at::kLong).device(device));
      offset_t = at::empty({}, at::dtype(at::kLong).device(device));
    } else {
      auto [seed, offset] = at::cuda::philox::unpack(philox_state);
      seed_t = at::scalar_tensor(
          at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
      offset_t = at::scalar_tensor(
          at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
    }
  } else {
    // Not using dropout
    seed_t = at::empty({}, at::dtype(at::kLong).device(device));
    offset_t = at::empty({}, at::dtype(at::kLong).device(device));
  }

#ifdef USE_ROCM
  // ROCM Implementation
  auto ret = aotriton::v2::flash::check_gpu(stream);
  if (hipSuccess != ret) {
    TORCH_CHECK(false,
                "[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx94a:sramecc+:xnack-)")
  }

  // AOTriton may accept aligned on logsumexp tensor in the future for better
  // performance, but for now it requires compact logsumexp tensor, even if
  // compute_logsumexp is false
  constexpr int kAlignLSE = 1;
  res = at::empty({B, M, num_heads, Kv}, query.options());
  logsumexp = at::empty(
      { B, num_heads, max_seqlen_q },
      query.options().dtype(at::ScalarType::Float));
  at::Tensor softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q});
  at::Tensor q_t = query.transpose(1, 2);
  at::Tensor k_t = key.transpose(1, 2);
  at::Tensor v_t = value.transpose(1, 2);
  at::Tensor output_t = res.transpose(1, 2);
  bool is_causal;
  if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
    is_causal = true;
  } else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
    is_causal = false;
  } else {
    TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
  }

  const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();

  using aotriton::v2::flash::attn_fwd;
  using sdp::aotriton_adapter::mk_aotensor;
  aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16);
  at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options());
  // NOTE(review): seed_t/offset_t are dereferenced on the host below. During
  // graph capture they are allocated uninitialised on the device (see above),
  // so dropout under capture looks unsafe on this path - confirm.
  hipError_t err = attn_fwd(mk_aotensor(q_t, "q"),
                            mk_aotensor(k_t, "k"),
                            mk_aotensor(v_t, "v"),
                            bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
                            softmax_scale,
                            mk_aotensor<2>(softmax_lse, "M"),
                            mk_aotensor(output_t, "Out"),
                            dropout_p,
                            use_dropout ? *seed_t.data_ptr<int64_t>() : 0,
                            use_dropout ? *offset_t.data_ptr<int64_t>() : 0,
                            mk_aotensor(softmax_fa_t, "encoded_softmax"),
                            is_causal,
                            stream);
  // The status used to be dropped; a failed launch would have returned the
  // uninitialised `res` buffer to the caller.
  TORCH_CHECK(
      err == hipSuccess,
      "[AOTriton] attn_fwd failed with hipError_t code ",
      static_cast<int>(err));
  if (!compute_logsumexp) {
    // Set the tensor to empty when compute_logsumexp is false
    logsumexp = at::empty(
        { B * num_heads, max_seqlen_q, 0 },
        query.options().dtype(at::ScalarType::Float));
  }
#else
  // CUDA Implementation
  cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());
  const int computeCapability = p->major * 10 + p->minor;

  bool kernel_launched = false;
  const auto maxShmem = p->sharedMemPerBlockOptin;

  // Tried with each candidate kernel type; returns without side effects when
  // the kernel is incompatible or when an earlier candidate already launched.
  auto launchKernel = [&](auto _k, auto kernel_fn) {
    using Kernel = decltype(_k);
    using scalar_t = typename Kernel::scalar_t;
    (void)_k;

    if (kernel_launched) {
      return;
    }
    // Check if this kernel is compatible
    if (!Kernel::kSupportsDropout && use_dropout) {
      return;
    }
    if (!Kernel::kSupportsBias && bias.has_value()) {
      return;
    }

    if (value.size(3) > Kernel::kMaxK || key.size(3) > Kernel::kMaxK) {
      return;
    }
    // Alignment
    if ((query.stride(2) % Kernel::kAlignmentQ) ||
        (key.stride(2) % Kernel::kAlignmentK) ||
        (value.stride(2) % Kernel::kAlignmentV)) {
      return;
    }
    // Uses too much shmem
    size_t smem_bytes = sizeof(typename Kernel::SharedStorage);
    if (smem_bytes > maxShmem) {
      return;
    }
    // From here on this kernel is committed to: later failures raise instead
    // of falling through to another candidate.
    kernel_launched = true;

    res = at::empty(
        {B, M, num_heads, Kv},
        query.options().dtype(
            CutlassToAtenDtype<typename Kernel::output_t>::atScalarType()));

    // NOTE: Should be aligned (by padding) in case M is
    // not a good number for loading during backward
    constexpr decltype(M) kAlignLSE = Kernel::kAlignLSE;
    logsumexp = at::empty(
        {seqstart_q.has_value() ? seqstart_q->size(0) - 1 : B,
         num_heads,
         compute_logsumexp ? ceil_div(max_seqlen_q, kAlignLSE) * kAlignLSE : 0},
        query.options().dtype(at::ScalarType::Float));
    typename Kernel::Params p;
    p.query_ptr = (const scalar_t*)query.const_data_ptr();
    p.key_ptr = (const scalar_t*)key.const_data_ptr();
    p.value_ptr = (const scalar_t*)value.const_data_ptr();
    p.logsumexp_ptr = compute_logsumexp
        ? (typename Kernel::lse_scalar_t*)logsumexp.data_ptr()
        : nullptr;
    at::Tensor output_accum;
    if (Kernel::kNeedsOutputAccumulatorBuffer) {
      output_accum = at::empty(
          {B, M, num_heads, Kv},
          query.options().dtype(
              CutlassToAtenDtype<
                  typename Kernel::output_accum_t>::atScalarType()));
      p.output_accum_ptr =
          (typename Kernel::output_accum_t*)output_accum.data_ptr();
    } else {
      p.output_accum_ptr = nullptr;
    }
    p.output_ptr = (typename Kernel::output_t*)res.data_ptr();

    if (seqstart_q.has_value()) {
      p.seqstart_q_ptr = (const int32_t*)seqstart_q->const_data_ptr();
      p.seqstart_k_ptr = (const int32_t*)seqstart_k->const_data_ptr();
    }

    p.num_heads = num_heads;
    p.head_dim = query.size(3);
    p.head_dim_value = value.size(3);
    p.num_queries = max_seqlen_q;
    p.num_keys = max_seqlen_k;
    p.num_batches = seqstart_q.has_value() ? seqstart_q->size(0) - 1 : B;
    p.custom_mask_type = custom_mask_type;

    p.seqlen_k_ptr = nullptr;
    if (seqlen_k.has_value()) {
      CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(seqlen_k.value());
      TORCH_CHECK(seqlen_k->scalar_type() == at::ScalarType::Int);
      p.seqlen_k_ptr = (const int32_t*)seqlen_k->const_data_ptr();
    }
    if (window_size.has_value()) {
      p.window_size = *window_size;
    }
    p.scale = sdp::calculate_scale(query, scale).as_float_unchecked();

    ASSIGN_CHECK_OVERFLOW(p.q_strideB, query.stride(0));
    ASSIGN_CHECK_OVERFLOW(p.k_strideB, key.stride(0));
    ASSIGN_CHECK_OVERFLOW(p.v_strideB, value.stride(0));
    ASSIGN_CHECK_OVERFLOW(p.q_strideM, query.stride(1));
    ASSIGN_CHECK_OVERFLOW(p.k_strideM, key.stride(1));
    ASSIGN_CHECK_OVERFLOW(p.v_strideM, value.stride(1));
    ASSIGN_CHECK_OVERFLOW(p.q_strideH, query.stride(2));
    ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2));
    ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2));
    ASSIGN_CHECK_OVERFLOW(p.o_strideM, res.stride(1));

    if (bias.has_value()) {
      CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA((*bias));
      TORCH_CHECK(
          bias->scalar_type() == CutlassToAtenDtype<scalar_t>::atScalarType(),
          "invalid dtype for bias - should match query's dtype");
      p.attn_bias_ptr = (const scalar_t*)bias->const_data_ptr();

      TORCH_CHECK(bias->dim() == 4, "Bias expected in BMHK format");
      TORCH_CHECK(
          bias->size(0) == query.size(0),
          "attn_bias: wrong shape (batch dimension)");
      TORCH_CHECK(
          bias->size(1) == query.size(2),
          "attn_bias: wrong shape (head dimension)");
      TORCH_CHECK(
          bias->size(2) == query.size(1),
          "attn_bias: wrong shape (seqlenQ dimension)");
      TORCH_CHECK(
          bias->size(3) == key.size(1),
          "attn_bias: wrong shape (seqlenKV dimension)");
      ASSIGN_CHECK_OVERFLOW(p.bias_strideB, bias->stride(0));
      ASSIGN_CHECK_OVERFLOW(p.bias_strideH, bias->stride(1));
      ASSIGN_CHECK_OVERFLOW(p.bias_strideM, bias->stride(2));
      TORCH_CHECK(
          bias->stride(3) == 1,
          "attn_bias: wrong alignment (last dimension must be contiguous)");
    }

    p.use_dropout = use_dropout;
    if (p.use_dropout) {
      p.rng_engine_inputs = philox_state;
      p.dropout_prob = dropout_p;
      p.seed = seed_t.data_ptr<int64_t>();
      p.extragraph_offset = offset_t.data_ptr<int64_t>();
    }

    // 0xc000 bytes = 48 KiB; larger dynamic shared memory must be requested
    // explicitly through the function attribute.
    if (smem_bytes > 0xc000) {
      auto err = cudaFuncSetAttribute(
          kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
      TORCH_CHECK(
          err != cudaErrorInvalidValue,
          "This GPU does not have enough shared-memory (kernel requires ",
          smem_bytes / 1024,
          " kb)");
      AT_CUDA_CHECK(err);
    }
    auto blocks = p.getBlocksGrid();
    if (blocks.x * blocks.y * blocks.z == 0 || key.size(1) == 0) {
      // Nothing to compute: skip the launch and return an all-zero output.
      res.zero_();
      return;
    }
    Kernel::check_supported(p);
    kernel_fn<<<blocks, p.getThreadsGrid(), smem_bytes, stream>>>(p);
  };

  // Dispatch to the right kernel
  DISPATCH_TYPES(query, ([&]() {
                   dispatch_cutlassF<scalar_t>(launchKernel, computeCapability);
                 }));
  TORCH_CHECK(kernel_launched, "cutlassF: no kernel found to launch!");
  AT_CUDA_CHECK(cudaGetLastError());

#endif // USE_ROCM
  return std::make_tuple(
      std::move(res),
      std::move(logsumexp),
      std::move(seed_t),
      std::move(offset_t),
      max_seqlen_q,
      // TODO: why isn't this being set in the kernel?
      max_seqlen_k_.has_value() ? max_seqlen_k_.value() : max_seqlen_k);
#endif
  // Only reached when the build lacks memory-efficient attention support.
  TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.")
  return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}, 0, 0);
}
1313
// Placeholder implementation: always raises. The error text states that the
// operator is expected to be overridden from Python before it is called.
Tensor triton_scaled_dot_attention(const Tensor& q, const Tensor& k, const Tensor& v, double dropout_p){
  TORCH_CHECK(false, "This operator should be overridden in python before use");
  // Not reached; keeps the function well-formed for its return type.
  return Tensor();
}
1318
1319 REGISTER_CUDA_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cuda);
1320
#if defined(USE_MEM_EFF_ATTENTION) and !defined(USE_ROCM)
namespace {
/**
 * simple kernel that populates a tensor with rand uniform values.
 * currently only used for testing purposes, not much attention
 * is paid to performance.
 *
 * problem is partitioned as follows:
 *  - (batch, head) is given by block coordinates
 *  - each thread handles a row for a given (batch, head)
 *
 * Launch layout expected by the indexing below: grid = (batch, n_heads),
 * block = n_queries threads, no shared memory.
 *
 * All linear offsets are computed in 64 bits so that masks with 2^31 or more
 * elements are addressed correctly (the per-(batch, head) base offset used to
 * be truncated to `int`).
 *
 * `dropout_prob` is part of the signature but is not read by the kernel.
 */
template <typename mask_t>
__global__ void rand_uniform_kernel(
    int64_t n_heads,
    int64_t n_queries,
    int64_t n_keys,
    float dropout_prob,
    at::PhiloxCudaState rng_engine_inputs,
    mask_t* mask_out,
    int64_t mask_numel) {
  const int64_t batch_id = blockIdx.x;
  const int64_t head_id = blockIdx.y;
  const int64_t query_idx = threadIdx.x;

  const auto seeds = at::cuda::philox::unpack(rng_engine_inputs);

  // Linear offset of the first element of this (batch, head) slice.
  const int64_t dropout_seq_start = batch_id * (n_heads * n_queries * n_keys) +
      head_id * (n_queries * n_keys);
  // Offset of this thread's row inside the slice.
  const int64_t query_start_idx = query_idx * n_keys;

  // The generator offset is the caller's offset plus the linear index of the
  // first element of this row.
  curandStatePhilox4_32_10_t curand_state;
  curand_init(
      std::get<0>(seeds),
      0,
      std::get<1>(seeds) + dropout_seq_start + query_start_idx,
      &curand_state);

  // Values are drawn four at a time. The last quad of a row may extend past
  // n_keys; those writes are bounded only by mask_numel.
  for (int64_t key_start_idx = 0; key_start_idx < n_keys; key_start_idx += 4) {
    float4 rand_quad = curand_uniform4(&curand_state);

#pragma unroll
    for (int i = 0; i < 4; ++i) {
      const int64_t linear_idx = dropout_seq_start + query_start_idx + key_start_idx + i;
      if (linear_idx < mask_numel) {
        mask_out[linear_idx] = (&rand_quad.x)[i];
      }
    }
  }
}
} // namespace
#endif // defined(USE_MEM_EFF_ATTENTION) and !defined(USE_ROCM)
/**
 * fill tensor with random uniform values. only used for testing, not much
 * attention is paid to performance
 *
 * `self` must be contiguous and of dtype Float; its sizes along dims 0..3 are
 * read as (batch, heads, queries, keys). `seed` and `offset` select the
 * Philox stream. Returns `self`.
 *
 * Raises when the checks on `self` fail, when the fill kernel cannot be
 * launched or reports an error, or when the build does not define
 * USE_MEM_EFF_ATTENTION.
 */
at::Tensor& _fill_mem_eff_dropout_mask_(
    Tensor& self,
    double dropout_p,
    const int64_t seed,
    const int64_t offset) {
  TORCH_CHECK(self.is_contiguous());
  TORCH_CHECK(self.dtype() == at::ScalarType::Float);
  const int64_t batch_sz = self.size(0);
  const int64_t n_heads = self.size(1);
  const int64_t n_queries = self.size(2);
  const int64_t n_keys = self.size(3);
#if defined(USE_MEM_EFF_ATTENTION)

#ifdef USE_ROCM
  using aotriton::v2::flash::debug_fill_dropout_rng;
  using sdp::aotriton_adapter::mk_aotensor;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const hipError_t err = debug_fill_dropout_rng(mk_aotensor(self, "r"),
                                                static_cast<uint64_t>(seed),
                                                static_cast<uint64_t>(offset),
                                                stream);
  // The status used to be ignored; surface it instead of silently returning
  // a tensor that may not have been filled.
  TORCH_CHECK(
      err == hipSuccess,
      "[AOTriton] debug_fill_dropout_rng failed with hipError_t code ",
      static_cast<int>(err));
#else
  // A zero-sized grid or block is an invalid launch configuration; there is
  // nothing to fill in that case.
  if (self.numel() == 0) {
    return self;
  }

  at::PhiloxCudaState rng_engine_inputs;
  rng_engine_inputs = at::PhiloxCudaState(seed, offset);
  at::cuda::CUDAGuard device_guard(self.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // One block per (batch, head), one thread per query row.
  rand_uniform_kernel<float><<<dim3(batch_sz, n_heads), n_queries, 0, stream>>>(
      n_heads,
      n_queries,
      n_keys,
      dropout_p,
      rng_engine_inputs,
      reinterpret_cast<float*>(self.data_ptr()),
      self.numel());
  // Launch-configuration errors (for example n_queries above the per-block
  // thread limit) are only reported through the last-error state.
  AT_CUDA_CHECK(cudaGetLastError());
#endif

  return self;
#endif
  // Only reached when the build lacks memory-efficient attention support.
  TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.")
  return self;
}
1420
1421 } // namespace native
1422 } // namespace at
1423