xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/attention.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <type_traits>
3 
4 #include <ATen/core/Tensor.h>
5 #include <ATen/AccumulateType.h>
6 #include <ATen/Dispatch.h>
7 #include <ATen/native/DispatchStub.h>
8 #include <ATen/NestedTensorImpl.h>
9 #include <ATen/TensorAccessor.h>
10 #include <ATen/TensorOperators.h>
11 #include <c10/util/Logging.h>
12 #include <c10/util/bit_cast.h>
13 
14 #include <ATen/cuda/CUDAContext.h>
15 #include <ATen/cuda/CUDAGraphsUtils.cuh>
16 #include <ATen/cuda/detail/KernelUtils.h>
17 #include <ATen/cuda/detail/IndexUtils.cuh>
18 #include <ATen/native/NonSymbolicBC.h>
19 #include <ATen/native/cuda/Loops.cuh>
20 #include <ATen/native/cuda/MemoryAccess.cuh>
21 #include <ATen/native/cuda/PersistentSoftmax.cuh>
22 #include <ATen/native/cuda/block_reduce.cuh>
23 #include <c10/util/Optional.h>
24 
25 #ifndef AT_PER_OPERATOR_HEADERS
26 #include <ATen/Functions.h>
27 #include <ATen/NativeFunctions.h>
28 #else
29 #include <ATen/ops/_efficient_attention_forward.h>
30 #include <ATen/ops/_efficient_attention_forward_native.h>
31 #include <ATen/ops/_fill_mem_eff_dropout_mask_native.h>
32 #include <ATen/ops/_flash_attention_forward.h>
33 #include <ATen/ops/_flash_attention_forward_native.h>
34 #include <ATen/ops/_fused_sdp_choice_native.h>
35 #include <ATen/ops/_masked_softmax.h>
36 #include <ATen/ops/_native_multi_head_attention_native.h>
37 #include <ATen/ops/scaled_dot_product_attention_native.h>
38 #include <ATen/ops/_scaled_dot_product_efficient_attention.h>
39 #include <ATen/ops/_scaled_dot_product_efficient_attention_native.h>
40 #include <ATen/ops/_scaled_dot_product_flash_attention.h>
41 #include <ATen/ops/_scaled_dot_product_flash_attention_native.h>
42 #include <ATen/ops/_softmax.h>
43 #include <ATen/ops/_transform_bias_rescale_qkv.h>
44 #include <ATen/ops/_triton_multi_head_attention_native.h>
45 #include <ATen/ops/_triton_scaled_dot_attention.h>
46 #include <ATen/ops/empty.h>
47 #include <ATen/ops/empty_like.h>
48 #include <ATen/ops/linear.h>
49 #include <ATen/ops/narrow_native.h>
50 #include <ATen/ops/scalar_tensor.h>
51 #include <ATen/ops/scaled_dot_product_attention.h>
52 #include <ATen/ops/split_native.h>
53 #include <ATen/ops/zeros.h>
54 #endif
55 
56 #ifdef __HIP_PLATFORM_AMD__
57 #include <ATen/native/cudnn/hip/MHA.h>
58 #else
59 #include <ATen/native/cudnn/MHA.h>
60 #endif
61 
62 #include <c10/cuda/CUDAMathCompat.h>
63 
64 #include <ATen/native/transformers/attention.h>
65 #include <ATen/native/nested/NestedTensorUtils.h>
66 #include <ATen/native/nested/NestedTensorTransformerFunctions.h>
67 #include <ATen/native/transformers/cuda/sdp_utils.h>
68 #include <ATen/native/transformers/sdp_utils_cpp.h>
69 
70 #ifdef USE_FLASH_ATTENTION
71 // FlashAttention Specific Imports
72 #include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
73 #endif
74 #ifdef USE_MEM_EFF_ATTENTION
75 #ifndef USE_ROCM
76 // MemoryEfficient Attention Specific Imports for CUDA
77 #include <ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h>
78 #include <ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h>
79 #include <ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h>
80 #else
81 // MemoryEfficient Attention Specific Imports for ROCM
82 #include <ATen/native/transformers/hip/aotriton_adapter.h>
83 #include <aotriton/flash.h>
84 #include <aotriton/runtime.h>
85 #endif
86 #endif
87 
88 namespace at {
89 
90 namespace native {
91 
92 namespace {
93 
94 
95 static constexpr int TRANSFORM_BIAS_RESCALE_VEC = 4;
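// Note (sketch of the invariants used below): the vectorized (assume_aligned)
// kernel paths load and store TRANSFORM_BIAS_RESCALE_VEC elements at a time, so
// they require D % VEC == 0 and DH % VEC == 0. The dispatcher in
// transform_bias_rescale_qkv_cuda only selects them after checking
// dim_per_head % VEC == 0 and the alignment of the qkv / qkv_bias data pointers.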
96 
97 template <typename scalar_t, typename accscalar_t, bool assume_aligned>
98 __global__ void transform_bias_rescale_qkv_kernel(
99     // [B, T, 3 * D]
100     const PackedTensorAccessor64<scalar_t, 3, RestrictPtrTraits> qkv,
101     // [3 * D]
102     const PackedTensorAccessor64<scalar_t, 1, RestrictPtrTraits> qkv_bias,
103     // [3, B, NH, T, DH]
104     PackedTensorAccessor64<scalar_t, 5, RestrictPtrTraits> q_k_v,
105     const scalar_t inv_sqrt_dim_per_head) {
106   // warp per DH.
107   // so launch B * NH * T warps.
108   auto NH = q_k_v.size(2);
109   auto T = q_k_v.size(3);
110   auto DH = q_k_v.size(4);
111 
112   auto t = blockIdx.x % T;
113   auto b = blockIdx.x / T;
114 
115   auto D = NH * DH;
116 
117   if (assume_aligned) {
118     constexpr int VEC = TRANSFORM_BIAS_RESCALE_VEC;
119     using LoadT = memory::aligned_vector<scalar_t, VEC>;
120     for (int32_t d_v = threadIdx.x; d_v < D / VEC; d_v += blockDim.x) {
121       auto d = d_v * VEC;
122       auto nh = d / DH;
123       auto dh = d % DH;
124       scalar_t qkv_bias_q[VEC];
125       scalar_t qkv_bias_k[VEC];
126       scalar_t qkv_bias_v[VEC];
127       scalar_t qkv_q[VEC];
128       scalar_t qkv_k[VEC];
129       scalar_t qkv_v[VEC];
130 
131       // Here we require D % VEC == 0 for these vectorized loads.
132       *reinterpret_cast<LoadT*>(&qkv_bias_q) =
133           *reinterpret_cast<const LoadT*>(&qkv_bias[d + 0 * D]);
134       *reinterpret_cast<LoadT*>(&qkv_bias_k) =
135           *reinterpret_cast<const LoadT*>(&qkv_bias[d + 1 * D]);
136       *reinterpret_cast<LoadT*>(&qkv_bias_v) =
137           *reinterpret_cast<const LoadT*>(&qkv_bias[d + 2 * D]);
138 
139       *reinterpret_cast<LoadT*>(&qkv_q) =
140           *reinterpret_cast<const LoadT*>(&qkv[b][t][d + 0 * D]);
141       *reinterpret_cast<LoadT*>(&qkv_k) =
142           *reinterpret_cast<const LoadT*>(&qkv[b][t][d + 1 * D]);
143       *reinterpret_cast<LoadT*>(&qkv_v) =
144           *reinterpret_cast<const LoadT*>(&qkv[b][t][d + 2 * D]);
145 
146 #pragma unroll
147       // TODO: specialize for float2half2/half2float2?
148       for (auto ii = 0; ii < VEC; ++ii) {
149         qkv_q[ii] = static_cast<scalar_t>(
150             (static_cast<accscalar_t>(qkv_q[ii]) +
151              static_cast<accscalar_t>(qkv_bias_q[ii])) *
152             static_cast<accscalar_t>(inv_sqrt_dim_per_head));
153         qkv_k[ii] = static_cast<scalar_t>(
154             (static_cast<accscalar_t>(qkv_k[ii]) +
155              static_cast<accscalar_t>(qkv_bias_k[ii])));
156         qkv_v[ii] = static_cast<scalar_t>(
157             (static_cast<accscalar_t>(qkv_v[ii]) +
158              static_cast<accscalar_t>(qkv_bias_v[ii])));
159       }
160 
161       // Here we require DH % VEC == 0 for these vectorized stores.
162       *reinterpret_cast<LoadT*>(&q_k_v[0][b][nh][t][dh]) =
163           *reinterpret_cast<const LoadT*>(&qkv_q);
164       *reinterpret_cast<LoadT*>(&q_k_v[1][b][nh][t][dh]) =
165           *reinterpret_cast<const LoadT*>(&qkv_k);
166       *reinterpret_cast<LoadT*>(&q_k_v[2][b][nh][t][dh]) =
167           *reinterpret_cast<const LoadT*>(&qkv_v);
168     }
169   } else {
170     // Same as above, but we can't vectorize memory access.
171     for (int32_t d = threadIdx.x; d < D; d += blockDim.x) {
172       auto nh = d / DH;
173       auto dh = d % DH;
174       scalar_t qkv_bias_q = qkv_bias[d + 0 * D];
175       scalar_t qkv_bias_k = qkv_bias[d + 1 * D];
176       scalar_t qkv_bias_v = qkv_bias[d + 2 * D];
177       scalar_t qkv_q = qkv[b][t][d + 0 * D];
178       scalar_t qkv_k = qkv[b][t][d + 1 * D];
179       scalar_t qkv_v = qkv[b][t][d + 2 * D];
180       qkv_q = static_cast<scalar_t>(
181           (static_cast<accscalar_t>(qkv_q) +
182            static_cast<accscalar_t>(qkv_bias_q)) *
183           static_cast<accscalar_t>(inv_sqrt_dim_per_head));
184       qkv_k = static_cast<scalar_t>(
185           (static_cast<accscalar_t>(qkv_k) +
186            static_cast<accscalar_t>(qkv_bias_k)));
187       qkv_v = static_cast<scalar_t>(
188           (static_cast<accscalar_t>(qkv_v) +
189            static_cast<accscalar_t>(qkv_bias_v)));
190 
191       q_k_v[0][b][nh][t][dh] = qkv_q;
192       q_k_v[1][b][nh][t][dh] = qkv_k;
193       q_k_v[2][b][nh][t][dh] = qkv_v;
194     }
195   }
196 }
197 
198 template <typename scalar_t, typename accscalar_t, bool assume_aligned = false>
199 __global__ void transform_bias_rescale_qkv_add_padding_kernel(
200     // [B, T, 3 * D], but it's a NestedTensor buffer
201     const PackedTensorAccessor64<scalar_t, 1, RestrictPtrTraits> qkv,
202     // [3 * D]
203     const PackedTensorAccessor64<scalar_t, 1, RestrictPtrTraits> qkv_bias,
204     const int* offsets,
205     const int* input_sizes,
206     // [3, B, NH, T, DH]
207     PackedTensorAccessor64<scalar_t, 5, RestrictPtrTraits> q_k_v,
208     const scalar_t inv_sqrt_dim_per_head) {
209   // warp per DH.
210   // so launch B * NH * T warps.
211   const auto NH = q_k_v.size(2);
212   const auto T = q_k_v.size(3);
213   const auto DH = q_k_v.size(4);
214 
215   const auto t = blockIdx.x % T;
216   const auto b = blockIdx.x / T;
217 
218   const auto D = NH * DH;
219   const auto _3D = 3 * D;
220 
221   const auto offset_for_batch = offsets[b];
222   const auto input_dim = 1;
223   const auto* sizes_i = input_sizes + b * input_dim;
224   if (assume_aligned) {
225     constexpr int VEC = TRANSFORM_BIAS_RESCALE_VEC;
226     using LoadT = memory::aligned_vector<scalar_t, VEC>;
227     for (int32_t d_v = threadIdx.x; d_v < D / VEC; d_v += blockDim.x) {
228       auto d = d_v * VEC;
229       auto nh = d / DH;
230       auto dh = d % DH;
231       scalar_t qkv_bias_q[VEC];
232       scalar_t qkv_bias_k[VEC];
233       scalar_t qkv_bias_v[VEC];
234       scalar_t qkv_q[VEC];
235       scalar_t qkv_k[VEC];
236       scalar_t qkv_v[VEC];
237 
238       const auto first_item_offset = t * _3D + d;
239       const auto last_item_offset = first_item_offset + VEC - 1;
240       const bool first_item_in_bounds = first_item_offset < sizes_i[0];
241       const bool entire_vec_in_bounds = last_item_offset < sizes_i[0];
242 
243       // Here we require D % VEC == 0 for these vectorized loads.
244       *reinterpret_cast<LoadT*>(&qkv_bias_q) =
245           *reinterpret_cast<const LoadT*>(&qkv_bias[d + 0 * D]);
246       *reinterpret_cast<LoadT*>(&qkv_bias_k) =
247           *reinterpret_cast<const LoadT*>(&qkv_bias[d + 1 * D]);
248       *reinterpret_cast<LoadT*>(&qkv_bias_v) =
249           *reinterpret_cast<const LoadT*>(&qkv_bias[d + 2 * D]);
250 
251       if (entire_vec_in_bounds) {
252         const auto offset = offset_for_batch + first_item_offset;
253         *reinterpret_cast<LoadT*>(&qkv_q) =
254             *reinterpret_cast<const LoadT*>(&qkv[offset + 0 * D]);
255         *reinterpret_cast<LoadT*>(&qkv_k) =
256             *reinterpret_cast<const LoadT*>(&qkv[offset + 1 * D]);
257         *reinterpret_cast<LoadT*>(&qkv_v) =
258             *reinterpret_cast<const LoadT*>(&qkv[offset + 2 * D]);
259 #pragma unroll
260         // TODO: specialize for float2half2/half2float2?
261         for (auto ii = 0; ii < VEC; ++ii) {
262           qkv_q[ii] = static_cast<scalar_t>(
263               (static_cast<accscalar_t>(qkv_q[ii]) +
264                static_cast<accscalar_t>(qkv_bias_q[ii])) *
265               static_cast<accscalar_t>(inv_sqrt_dim_per_head));
266           qkv_k[ii] = static_cast<scalar_t>(
267               (static_cast<accscalar_t>(qkv_k[ii]) +
268                static_cast<accscalar_t>(qkv_bias_k[ii])));
269           qkv_v[ii] = static_cast<scalar_t>(
270               (static_cast<accscalar_t>(qkv_v[ii]) +
271                static_cast<accscalar_t>(qkv_bias_v[ii])));
272         }
273       } else if (first_item_in_bounds) {
274         const auto offset = offset_for_batch + first_item_offset;
275         qkv_q[0] = qkv[offset + 0 * D];
276         qkv_k[0] = qkv[offset + 1 * D];
277         qkv_v[0] = qkv[offset + 2 * D];
278         qkv_q[0] = static_cast<scalar_t>(
279             (static_cast<accscalar_t>(qkv_q[0]) +
280              static_cast<accscalar_t>(qkv_bias_q[0])) *
281             static_cast<accscalar_t>(inv_sqrt_dim_per_head));
282         qkv_k[0] = static_cast<scalar_t>(
283             (static_cast<accscalar_t>(qkv_k[0]) +
284              static_cast<accscalar_t>(qkv_bias_k[0])));
285         qkv_v[0] = static_cast<scalar_t>(
286             (static_cast<accscalar_t>(qkv_v[0]) +
287              static_cast<accscalar_t>(qkv_bias_v[0])));
288 #pragma unroll
289         for (auto ii = 1; ii < VEC; ++ii) {
290           const auto loop_offset = offset + ii;
291           if (loop_offset < sizes_i[0]) {
292             qkv_q[ii] = qkv[loop_offset + 0 * D];
293             qkv_k[ii] = qkv[loop_offset + 1 * D];
294             qkv_v[ii] = qkv[loop_offset + 2 * D];
295             qkv_q[ii] = static_cast<scalar_t>(
296                 (static_cast<accscalar_t>(qkv_q[ii]) +
297                  static_cast<accscalar_t>(qkv_bias_q[ii])) *
298                 static_cast<accscalar_t>(inv_sqrt_dim_per_head));
299             qkv_k[ii] = static_cast<scalar_t>(
300                 (static_cast<accscalar_t>(qkv_k[ii]) +
301                  static_cast<accscalar_t>(qkv_bias_k[ii])));
302             qkv_v[ii] = static_cast<scalar_t>(
303                 (static_cast<accscalar_t>(qkv_v[ii]) +
304                  static_cast<accscalar_t>(qkv_bias_v[ii])));
305           } else {
306             qkv_q[ii] = 0;
307             qkv_k[ii] = 0;
308             qkv_v[ii] = 0;
309           }
310         }
311       } else {
312 #pragma unroll
313         for (auto ii = 0; ii < VEC; ++ii) {
314           qkv_q[ii] = 0;
315           qkv_k[ii] = 0;
316           qkv_v[ii] = 0;
317         }
318       }
319 
320       // Here we require DH % VEC == 0 for these vectorized stores.
321       *reinterpret_cast<LoadT*>(&q_k_v[0][b][nh][t][dh]) =
322           *reinterpret_cast<const LoadT*>(&qkv_q);
323       *reinterpret_cast<LoadT*>(&q_k_v[1][b][nh][t][dh]) =
324           *reinterpret_cast<const LoadT*>(&qkv_k);
325       *reinterpret_cast<LoadT*>(&q_k_v[2][b][nh][t][dh]) =
326           *reinterpret_cast<const LoadT*>(&qkv_v);
327     }
328   } else {
329     for (int32_t d = threadIdx.x; d < D; d += blockDim.x) {
330       auto nh = d / DH;
331       auto dh = d % DH;
332       scalar_t qkv_bias_q = qkv_bias[d + 0 * D];
333       scalar_t qkv_bias_k = qkv_bias[d + 1 * D];
334       scalar_t qkv_bias_v = qkv_bias[d + 2 * D];
335 
336       const auto item_offset = t * _3D + d;
337       const bool in_bounds = item_offset < sizes_i[0];
338       scalar_t qkv_q, qkv_k, qkv_v;
339       if (in_bounds) {
340         const auto qkv_offset = offset_for_batch + item_offset;
341         qkv_q = qkv[qkv_offset + 0 * D];
342         qkv_k = qkv[qkv_offset + 1 * D];
343         qkv_v = qkv[qkv_offset + 2 * D];
344         qkv_q = static_cast<scalar_t>(
345             (static_cast<accscalar_t>(qkv_q) +
346              static_cast<accscalar_t>(qkv_bias_q)) *
347             static_cast<accscalar_t>(inv_sqrt_dim_per_head));
348         qkv_k = static_cast<scalar_t>(
349             (static_cast<accscalar_t>(qkv_k) +
350              static_cast<accscalar_t>(qkv_bias_k)));
351         qkv_v = static_cast<scalar_t>(
352             (static_cast<accscalar_t>(qkv_v) +
353              static_cast<accscalar_t>(qkv_bias_v)));
354       } else {
355         qkv_q = 0;
356         qkv_k = 0;
357         qkv_v = 0;
358       }
359 
360       q_k_v[0][b][nh][t][dh] = qkv_q;
361       q_k_v[1][b][nh][t][dh] = qkv_k;
362       q_k_v[2][b][nh][t][dh] = qkv_v;
363     }
364   }
365 }
366 
367 Tensor collapse_dims_1_and_2(const Tensor& sizes) {
368   auto sizes_dim1 = at::native::narrow_symint(sizes, 1, 0, 1);
369   auto sizes_dim2 = at::native::narrow_symint(sizes, 1, 1, 1);
370 
371   return (sizes_dim1 * sizes_dim2).contiguous();
372 }
373 
374 } // namespace
375 // compute q = (q + q_bias) / sqrt(dim_per_head), k = k + k_bias, v = v + v_bias
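// Shapes (as used below): qkv is [B, T, 3 * D] (or a nested-tensor buffer),
// qkv_bias is [3 * D], and the result is a single [3, B, num_head, T,
// dim_per_head] tensor that is split into q, k, v views. For nested input, T is
// the max sequence length rounded up to a multiple of 8.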
376 __host__ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_cuda(
377     const Tensor& qkv,
378     const Tensor& qkv_bias,
379     const int64_t num_head) {
380   auto B = qkv.is_nested()
381       ? get_nested_tensor_impl(qkv)->get_nested_sizes().size(0)
382       : qkv.size(0);
383   // TODO: calculate this without the std::vector -- NestedTensor_to_mask wants
384   // this too
385   auto T = qkv.is_nested()
386       ? NestedTensor_get_max_size(*get_nested_tensor_impl(qkv))[0]
387       : qkv.size(1);
388   if (qkv.is_nested()) {
389     // Don't mess with non-nested case for now since it's not set up to fiddle
390     // with mask size.
391 
392     // Round T up to next multiple of 8 so as to be able to utilize Tensor
393     // cores. Otherwise, sometimes with padding, *no* row will have the maximum
394     // sequence length and so we'll have a non-divisible-by-8 dimension even if
395     // the model author chose a multiple of 8.
396     T = T + (8 - (T % 8)) % 8;
397   }
398   auto _3D = qkv_bias.size(0);
399   auto D = _3D / 3;
400   TORCH_CHECK(D % num_head == 0);
401   const auto dim_per_head = D / num_head;
402   auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv_bias.options());
403 #define CALL_KERNEL(assume_aligned)                                        \
404   transform_bias_rescale_qkv_kernel<scalar_t, accscalar_t, assume_aligned> \
405       <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(          \
406           qkv.packed_accessor64<scalar_t, 3, RestrictPtrTraits>(),         \
407           qkv_bias.packed_accessor64<scalar_t, 1, RestrictPtrTraits>(),    \
408           q_k_v.packed_accessor64<scalar_t, 5, RestrictPtrTraits>(),       \
409           1.0 / std::sqrt(static_cast<scalar_t>(dim_per_head)))
410 #define CALL_ADD_PADDING_KERNEL(assume_aligned)                         \
411   transform_bias_rescale_qkv_add_padding_kernel<                        \
412       scalar_t,                                                         \
413       accscalar_t,                                                      \
414       assume_aligned>                                                   \
415       <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(       \
416           nt_qkv_buffer                                          \
417               .packed_accessor64<scalar_t, 1, RestrictPtrTraits>(),     \
418           qkv_bias.packed_accessor64<scalar_t, 1, RestrictPtrTraits>(), \
419           offsets_ptr,                                                  \
420           sizes_ptr,                                                    \
421           q_k_v.packed_accessor64<scalar_t, 5, RestrictPtrTraits>(),    \
422           1.0 / std::sqrt(static_cast<scalar_t>(dim_per_head)))
423 
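  // Launch configuration shared by both macros above: one block per
  // (batch, token) pair (blocks = B * T) and min(1024, D / TRANSFORM_BIAS_RESCALE_VEC)
  // threads per block; each thread strides over the channel dimension,
  // VEC elements at a time in the aligned path.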
424   AT_DISPATCH_FLOATING_TYPES_AND2(
425       ScalarType::Half,
426       ScalarType::BFloat16,
427       qkv.scalar_type(),
428       "transform_bias_rescale_qkv",
429       [&] {
430         using accscalar_t = acc_type<scalar_t, true>;
431         auto threads = std::max(
432             std::min<int32_t>(1024, D / TRANSFORM_BIAS_RESCALE_VEC), 1);
433         auto blocks = B * T;
434         const bool aligned =
435             ((dim_per_head % TRANSFORM_BIAS_RESCALE_VEC) == 0) &&
436             ((reinterpret_cast<intptr_t>(qkv_bias.data_ptr()) %
437               TRANSFORM_BIAS_RESCALE_VEC) == 0);
438         if (aligned) {
439           TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
440               D % TRANSFORM_BIAS_RESCALE_VEC == 0,
441               "D = num_heads * dim_per_head, so we should have dim_per_head % "
442               "TRANSFORM_BIAS_RESCALE_VEC == 0 => "
443               "D % TRANSFORM_BIAS_RESCALE_VEC == 0");
444         }
445         if (qkv.is_nested()) {
446           auto* nt_qkv = get_nested_tensor_impl(qkv);
447           const at::Tensor& nt_qkv_buffer = nt_qkv->get_buffer();
448           auto sizes = collapse_dims_1_and_2(nt_qkv->get_nested_sizes());
449           auto offsets =
450               NestedTensor_batch_offsets_from_size_tensor(sizes, sizes.numel());
451           at::native::narrow_symint(offsets, 0, sizes.numel() + 1, sizes.numel())
452               .copy_(sizes.reshape({-1}));
453           auto metadata = offsets.to(at::Device(kCUDA), at::kInt, true, true);
454           const auto offsets_ptr = metadata.data_ptr<int>();
455           const auto sizes_ptr = offsets_ptr + sizes.numel() + 1;
456           const auto input_dim = sizes.sizes()[1];
457           TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input_dim == 1);
458           if (aligned &&
459               ((reinterpret_cast<intptr_t>(qkv.data_ptr()) %
460                 TRANSFORM_BIAS_RESCALE_VEC) == 0)) {
461             CALL_ADD_PADDING_KERNEL(true);
462           } else {
463             CALL_ADD_PADDING_KERNEL(false);
464           }
465         } else if (aligned) {
466           CALL_KERNEL(true);
467         } else {
468           CALL_KERNEL(false);
469         }
470         C10_CUDA_KERNEL_LAUNCH_CHECK();
471       });
472 #undef CALL_ADD_PADDING_KERNEL
473 #undef CALL_KERNEL
474   auto q_k_v_s =
475       at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0);
476   return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
477 }
478 
479 std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
480     const Tensor& query,
481     const Tensor& key,
482     const Tensor& value,
483     const int64_t embed_dim,
484     const int64_t num_head,
485     const Tensor& qkv_weight,
486     const Tensor& qkv_bias,
487     const Tensor& proj_weight,
488     const Tensor& proj_bias,
489     const std::optional<Tensor>& mask,
490     bool need_weights,
491     bool average_attn_weights,
492     const std::optional<int64_t> mask_type) {
493   // query shape: [B, T, D]
494   // qkv_weight shape: [3 * D, D]
495 
496   TORCH_CHECK(
497       !mask || !query.is_nested(),
498       "NestedTensor with mask is not supported yet");
499   const auto D = embed_dim;
500   TORCH_CHECK(
501       query.dim() == 3,
502       "expected 3-D `query`, got ",
503       query.dim(),
504       "-D tensor");
505   TORCH_CHECK(
506       query.is_nested() || query.sizes()[2] == embed_dim,
507       "passed-in embed_dim ",
508       embed_dim,
509       " didn't match last dim of query ",
510       query.sizes()[2]);
511   TORCH_CHECK(
512       key.dim() == 3,
513       "expected 3-D `key`, got ",
514       key.dim(),
515       "-D tensor");
516   TORCH_CHECK(
517       value.dim() == 3,
518       "expected 3-D `value`, got ",
519       value.dim(),
520       "-D tensor");
521   TORCH_CHECK(
522       query.is_nested() || key.is_nested() || value.is_nested() ||
523           (query.sizes() == key.sizes() && key.sizes() == value.sizes()),
524       "expected `query`/`key`/`value` shapes to match");
525   TORCH_CHECK(
526       qkv_weight.dim() == 2,
527       "expected 2-D `qkv_weight`, got ",
528       qkv_weight.dim(),
529       "-D tensor");
530   TORCH_CHECK(
531       D * 3 == qkv_weight.sizes()[0],
532       "expected `qkv_weight` first dim to be 3x embed_dim");
533   TORCH_CHECK(
534       D == qkv_weight.sizes()[1],
535       "expected `qkv_weight` second dim to be embed_dim");
536   TORCH_CHECK(
537       qkv_bias.dim() == 1,
538       "expected 1-D `qkv_bias`, got ",
539       qkv_bias.dim(),
540       "-D tensor");
541   TORCH_CHECK(
542       qkv_bias.sizes()[0] == 3 * D,
543       "expected `qkv_bias` first dim to be 3x embed_dim");
544   TORCH_CHECK(D % num_head == 0, "`embed_dim` must divide evenly by `num_heads`");
545 
546 #ifndef NDEBUG
547   const auto B = query.is_nested()
548       ? get_nested_tensor_impl(query)->get_nested_sizes().size(0)
549       : query.sizes()[0];
550   auto T = query.is_nested() ? 0 : query.sizes()[1];
551 
552 #endif
553   const auto dim_per_head = D / num_head;
554   if ((query.is_same(key) && key.is_same(value)) && dim_per_head % 8 == 0 && !need_weights) {
555 
556     // We have not done the linear projection yet, but the input for SDP
557     // is expected to be 4-dimensional. We "cheaply" create view tensors
558     // that are then used for checking hot path conditions with select_sdp_backend
559     auto q = query.view({query.size(0), -1, num_head, dim_per_head}).transpose(1, 2);
560     auto k = key.view({key.size(0), -1, num_head, dim_per_head}).transpose(1, 2);
561     auto v = value.view({value.size(0), -1, num_head, dim_per_head}).transpose(1, 2);
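    // These q/k/v views are only used for backend selection below; if a fused
    // backend is chosen, the projection is recomputed via at::linear and
    // chunked before calling scaled_dot_product_attention.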
562 
563     sdp::sdp_params kernel_params{q, k, v, mask, 0.0, false};
564     auto backend = select_sdp_backend(kernel_params);
565     // strides from the packed projection for nested tensors when seq_len is 1
566     // will trigger a contiguous call in the kernel, so we prevent this path
567     bool no_seq_len_1_nested = query.is_nested() ? check_for_seq_len_1_nested_tensor(kernel_params, false) : true;
568     // The transformer_encoder API passes a mask of shape (Batch_Size, Seq_len_q).
569     // For mem-eff attention this would cause the expand call to error,
570     // so for now we turn off that path rather than deal with all the
571     // mask type/shape handling
572     if (!mask.has_value() && no_seq_len_1_nested &&
573         (backend == sdp::SDPBackend::flash_attention || backend == sdp::SDPBackend::efficient_attention ||
574          backend == sdp::SDPBackend::cudnn_attention)) {
575       auto x = at::linear(query, qkv_weight, qkv_bias);
576       auto chunks = x.chunk(3, -1);
577       auto x_size_0 = x.size(0);
578 
579       chunks[0] = (chunks[0].view({x_size_0, -1, num_head, dim_per_head}))
580                       .transpose(1, 2);
581       chunks[1] = (chunks[1].view({x_size_0, -1, num_head, dim_per_head}))
582                       .transpose(1, 2);
583       chunks[2] = (chunks[2].view({x_size_0, -1, num_head, dim_per_head}))
584                       .transpose(1, 2);
585       auto y = at::scaled_dot_product_attention(
586           chunks[0], chunks[1], chunks[2], mask, 0.0, false, c10::nullopt);
587 
588       auto past_sdp = y.transpose(1, 2).reshape({x_size_0, -1, embed_dim});
589       return std::make_tuple(
590           at::linear(past_sdp, proj_weight, proj_bias), Tensor());
591     }
592     // Returned math or error, so let's not use the fused path
593   }
594 
595   // shape: [B, T, 3 x D]
596   auto qkv = qkv_projection(query, key, value, embed_dim, qkv_weight);
597 
598   if (!qkv.is_nested() && qkv.numel() == 0) {
599     if (query.is_nested()) {
600       return std::make_tuple(Tensor(), Tensor());
601     }
602     return std::make_tuple(at::empty_like(query), Tensor());
603   }
604 
605 #ifndef NDEBUG
606   if (!query.is_nested() || !qkv.is_nested()) {
607     if (query.is_nested()) {
608       T = qkv.size(1);
609     }
610     debug_assert_shape(__LINE__, qkv, {B, T, 3 * D});
611   }
612 #endif
613 
614 #ifdef DEBUG_PRINT_EACH_STEP
615   if (!qkv.is_nested()) {
616     std::cerr << "qkv: " << qkv << std::endl;
617   }
618 #endif
619   // shape: 3 x [B, num_head, T, dim_per_head]
620   auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
621   qkv = Tensor(); // Not used any more, allow free
622   auto& q = std::get<0>(q_k_v);
623   const auto& k = std::get<1>(q_k_v);
624   const auto& v = std::get<2>(q_k_v);
625 #ifndef NDEBUG
626   debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head});
627   debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head});
628   debug_assert_shape(__LINE__, v, {B, num_head, T, dim_per_head});
629 #endif
630 #ifdef DEBUG_PRINT_EACH_STEP
631   std::cerr << "q: " << q << std::endl;
632   std::cerr << "k: " << k << std::endl;
633   std::cerr << "v: " << v << std::endl;
634 #endif
635 
636   // shape: [B, num_head, T, T]
637   auto qkt = bmm_nt(q, k);
638   // q & k are dead but cannot be freed because they were packed with v
639 #ifndef NDEBUG
640   debug_assert_shape(__LINE__, qkt, {B, num_head, T, T});
641 #endif
642 #ifdef DEBUG_PRINT_EACH_STEP
643   std::cerr << "qkt: " << qkt << std::endl;
644 #endif
645 
646   // shape: [B, num_head, T, T]
647   // TODO: long-term, have a kernel that works with
648   // NestedTensor directly if there is no mask passed
649   qkt = masked_softmax(qkt, mask, query, mask_type);
650 #ifdef DEBUG_PRINT_EACH_STEP
651   std::cerr << "qkt after softmax: " << qkt << std::endl;
652 #endif
653 
654   // shape: [B, num_head, T, dim_per_head]
655   // reuse storage for q; we're done with it
656   auto attn_ctx = bmm_nn(q, qkt, v);
657   // qkv is not dead; we just reused storage for q!
658   if (!need_weights) {
659     qkt = Tensor();
660   }
661 #ifndef NDEBUG
662   debug_assert_shape(__LINE__, attn_ctx, {B, num_head, T, dim_per_head});
663 #endif
664 #ifdef DEBUG_PRINT_EACH_STEP
665   std::cerr << "attn_ctx: " << attn_ctx << std::endl;
666 #endif
667 
668   // shape: [B, T, D]
669   // Fuse transform_0213 inside
670   auto proj = transform0213_gemm_nt_bias(
671       attn_ctx, proj_weight, proj_bias, query);
672 #ifndef NDEBUG
673   debug_assert_shape(__LINE__, proj, {B, T, D});
674 #endif
675   if (need_weights && average_attn_weights) {
676     // weights are not needed for the full transformer, so don't worry too
677     // much about performance -- we implement this just so that use
678     // cases that don't disable need_weights still get some speedup.
679     qkt = qkt.sum(1);
680     qkt /= num_head;
681   }
682   return std::make_tuple(std::move(proj), std::move(qkt));
683 }
684 std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
685     const Tensor& query,
686     const Tensor& key,
687     const Tensor& value,
688     double dropout_p,
689     bool is_causal,
690     bool return_debug_mask,
691     std::optional<double> scale) {
692   // Used for tracking usage statistics
693   C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention");
694   // Query (Batch x Num_heads x Q_seq_len  x Dim_per_head)
695   // Key   (Batch x Num_heads x KV_seq_len x Dim_per_head)
696   // Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
697 
698   const int64_t max_seqlen_batch_q = query.size(2);
699   const int64_t max_seqlen_batch_k = key.size(2);
700   const int64_t max_seqlen_batch_v = value.size(2);
701   TORCH_CHECK(
702       max_seqlen_batch_k == max_seqlen_batch_v,
703       "Key and Value must have the same sequence length");
704 
705   // Query -> Query(Batch x Q_seq_len  x Num_heads x Dim_per_head)
706   // Key   -> Key  (Batch x KV_seq_len x Num_heads x Dim_per_head)
707   // Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head)
708   Tensor q_t = query.transpose(1, 2);
709   Tensor k_t = key.transpose(1, 2);
710   Tensor v_t = value.transpose(1, 2);
711 
712   auto
713       [output,
714        logsumexp,
715        philox_seed,
716        philox_offset,
717        debug_attn_mask] =
718           at::_flash_attention_forward(
719               q_t,
720               k_t,
721               v_t,
722               c10::nullopt,
723               c10::nullopt,
724               max_seqlen_batch_q,
725               max_seqlen_batch_k,
726               dropout_p,
727               is_causal,
728               return_debug_mask,
729               scale,
730               c10::nullopt,
731               c10::nullopt);
732   // Reshape output to convert nnz to batch_size and seq_len
733   Tensor attention = output.transpose(1,2);
734 
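  // The tuple below is (output, logsumexp, cum_seq_q, cum_seq_k, max_seqlen_q,
  // max_seqlen_k, philox_seed, philox_offset, debug_attn_mask); the two empty
  // tensors stand in for the cumulative sequence lengths, which the dense path
  // does not produce.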
735   return std::make_tuple(attention, logsumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, philox_seed, philox_offset, debug_attn_mask);
736 }
737 
738 std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_cuda(
739     const Tensor& query,
740     const Tensor& key,
741     const Tensor& value,
742     double dropout_p,
743     bool is_causal,
744     bool training,
745     std::optional<double> scale) {
746   // Used for tracking usage statistics
747   C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention_cudnn");
748   // Query (Batch x Num_heads x Q_seq_len  x Dim_per_head)
749   // Key   (Batch x Num_heads x KV_seq_len x Dim_per_head)
750   // Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
751   const int64_t batch_size = query.size(0);
752   const int64_t num_heads = query.size(1);
753   const int64_t max_seqlen_batch_q = query.size(2);
754   const int64_t head_dim_qk = query.size(3);
755   const int64_t head_dim_v = value.size(3);
756   const int64_t max_seqlen_batch_k = key.size(2);
757   const int64_t max_seqlen_batch_v = value.size(2);
758   TORCH_CHECK(
759       max_seqlen_batch_k == max_seqlen_batch_v,
760       "Key and Value must have the same sequence length");
761 
762   Tensor attention, log_sumexp;
763 
764   auto cudnn_seed = at::zeros({1}, query.options().dtype(kLong));
765   auto cudnn_offset = at::zeros({1}, query.options().dtype(kLong));
766   const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
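  // softmax_scale is the explicit `scale` if provided, otherwise the usual
  // 1 / sqrt(head_dim_qk) softmax scaling.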
767 
768   run_cudnn_SDP_fprop(batch_size/*int64_t b*/,
769                       num_heads/*int64_t h*/,
770                       max_seqlen_batch_q/*int64_t s_q*/,
771                       max_seqlen_batch_k/*int64_t s_kv*/,
772                       head_dim_qk/*int64_t d_qk*/,
773                       head_dim_v/*int64_t d_v*/,
774                       softmax_scale/*float scaling_factor*/,
775                       training/* bool */,
776                       is_causal/* bool */,
777                       dropout_p/*double dropout_probability*/,
778                       query/* Tensor q*/,
779                       key/* Tensor k*/,
780                       value/* Tensor v*/,
781                       log_sumexp/*Tensor softmaxstats*/,
782                       attention/*Tensor o*/,
783                       cudnn_seed/*Tensor dropoutseed*/,
784                       cudnn_offset/*Tensor dropoutoffset*/);
785 
786   return std::make_tuple(attention, log_sumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, cudnn_seed, cudnn_offset, Tensor());
787 }
788 
789 std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attention_cuda(
790     const Tensor& query,
791     const Tensor& key,
792     const Tensor& value,
793     const std::optional<at::Tensor>& attn_bias,
794     bool compute_log_sumexp,
795     double dropout_p,
796     bool is_causal,
797     std::optional<double> scale) {
798   // Used for tracking usage statistics
799   C10_LOG_API_USAGE_ONCE("torch.sdpa.mem_efficient_attention");
800   // Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head)
801   // Key   -> Key(Batch x KV_seq_len x Num_heads x Dim_per_head)
802   // Value -> Value(Batch x KV_seq_len x  Num_heads x Dim_per_head)
803   Tensor q_t = query.transpose(1, 2);
804   Tensor k_t = key.transpose(1, 2);
805   Tensor v_t = value.transpose(1, 2);
806 
807   sdp::CustomMaskType custom_mask_type = is_causal
808       ? sdp::CustomMaskType::CausalFromTopLeft
809       : sdp::CustomMaskType::NoCustomMask;
810 
811   auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] = at::_efficient_attention_forward(
812       q_t,
813       k_t,
814       v_t,
815       attn_bias,
816       c10::nullopt,
817       c10::nullopt,
818       c10::nullopt,
819       c10::nullopt,
820       dropout_p,
821       static_cast<int64_t>(custom_mask_type),
822       compute_log_sumexp,
823       scale);
824 
825   attention = attention.transpose(1, 2);
826   return std::make_tuple(std::move(attention), std::move(log_sumexp), std::move(seed), std::move(offset));
827 }
828 
829 int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value,
830         const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale){
831   sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal};
832   auto backend = select_sdp_backend(kernel_params);
833   if (backend == sdp::SDPBackend::error) {
834     TORCH_CHECK(
835         false,
836         "No viable backend for scaled_dot_product_attention was found. ",
837         "This is likely due to turning off both the math kernel and the fused kernels.");
838   }
839   return static_cast<int64_t>(backend);
840 }
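// The returned int64_t is the sdp::SDPBackend enum value (e.g. flash_attention,
// efficient_attention, cudnn_attention, or math); callers such as
// scaled_dot_product_attention use it to pick the corresponding kernel.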
841 
842 std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
843 _flash_attention_forward(
844     const Tensor& query,
845     const Tensor& key,
846     const Tensor& value,
847     const std::optional<Tensor>& cumulative_sequence_length_q,
848     const std::optional<Tensor>& cumulative_sequence_length_k,
849     int64_t max_seqlen_batch_q,
850     int64_t max_seqlen_batch_k,
851     double dropout_p,
852     bool is_causal,
853     bool return_debug_mask,
854     std::optional<double> scale,
855     std::optional<int64_t> window_size_left,
856     std::optional<int64_t> window_size_right,
857     const std::optional<Tensor>& _seqused_k,
858     const std::optional<Tensor>& _alibi_slopes
859     ) {
860 #if defined(USE_FLASH_ATTENTION)
861   const auto softmax_scale =
862       sdp::calculate_scale(query, scale).as_float_unchecked();
863   std::optional<Tensor> out = c10::nullopt;
864 
865   std::optional<Tensor> seqused_k = _seqused_k;
866   std::optional<Tensor> alibi_slopes = _alibi_slopes;
867 
868   const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
869   const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;
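  // -1 is passed when no window size is given; the flash kernels treat that as
  // an unbounded (no sliding-window) attention span on that side.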
870 
871   // We are going to have two paths:
872   // 1. The standard MHA path for dense tensors
873   // 2. The Varseqlen path
874   TORCH_CHECK(
875       cumulative_sequence_length_q.has_value() ==
876           cumulative_sequence_length_k.has_value(),
877       "cumulative_sequence_length_q and cumulative_sequence_length_k must be both set or both not set");
878   Tensor output, q_padded, k_padded, v_padded, logsumexp, output_shape,
879       philox_seed, philox_offset, debug_attn_mask;
880   if (cumulative_sequence_length_q.has_value()) {
881     std::tie(
882         output,
883         q_padded,
884         k_padded,
885         v_padded,
886         logsumexp,
887         philox_seed,
888         philox_offset,
889         debug_attn_mask) =
890         pytorch_flash::mha_varlen_fwd(
891             query,
892             key,
893             value,
894             out,
895             cumulative_sequence_length_q.value(),
896             cumulative_sequence_length_k.value(),
897             seqused_k, /*seqused_k*/
898             alibi_slopes, /*alibi_slopes*/
899             max_seqlen_batch_q,
900             max_seqlen_batch_k,
901             dropout_p,
902             softmax_scale,
903             false /*zero_tensors*/,
904             is_causal,
905             non_null_window_left,
906             non_null_window_right,
907             return_debug_mask,
908             c10::nullopt /*gen_*/);
909   } else {
910     std::tie(
911         output,
912         q_padded,
913         k_padded,
914         v_padded,
915         logsumexp,
916         philox_seed,
917         philox_offset,
918         debug_attn_mask) =
919         pytorch_flash::mha_fwd(
920             query,
921             key,
922             value,
923             out,
924             alibi_slopes,
925             dropout_p,
926             softmax_scale,
927             is_causal,
928             non_null_window_left,
929             non_null_window_right,
930             return_debug_mask, /*return_softmax (this is used for testing)*/
931             c10::nullopt);
932   }
933   debug_attn_mask =
934       return_debug_mask ? debug_attn_mask : at::empty({0}, query.options());
935   return std::make_tuple(
936       std::move(output),
937       std::move(logsumexp),
938       std::move(philox_seed),
939       std::move(philox_offset),
940       std::move(debug_attn_mask));
941 
942 #endif
943   TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.")
944   return std::make_tuple(
945       Tensor(),
946       Tensor(),
947       Tensor(),
948       Tensor(),
949       Tensor());
950 }
951 
952 std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_attention_forward(
953     const at::Tensor& query, // [b, seqlen, num_heads, K]
954     const at::Tensor& key, // [b, seqlen, num_heads, K]
955     const at::Tensor& value, // [b, seqlen, num_heads, Kv]
956     const std::optional<at::Tensor>& bias, // [b, num_heads, seqlen, seqlen]
957     // (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the
958     // position of the first query token for batch $b
959     const std::optional<at::Tensor>& seqstart_q,
960     // (Mode 1MHK only) [b+1]: cu_seqlen_k[b] contains the
961     // position of the first key token for batch $b
962     const std::optional<at::Tensor>& seqstart_k,
963     // (Mode 1MHK only) Maximum sequence length across batches
964     const std::optional<int64_t> max_seqlen_q_,
965     const std::optional<int64_t> max_seqlen_k_,
966     double dropout_p, // attention matrix dropout probability
967     int64_t custom_mask_type,
968     bool compute_logsumexp,
969     std::optional<double> scale,
970     const std::optional<at::Tensor>& seqlen_k,
971     const std::optional<int64_t> window_size) {
972 #if defined(USE_MEM_EFF_ATTENTION)
973 // TODO: In theory it is possible to compile with __CUDA_ARCH__ < 5.0 and run on a
974 // machine that is >= 5.0. In practice this is not a problem, but since
975 // this would avoid runtime architecture checks, we should look into it.
976 
977   TORCH_CHECK(query.dim() == 4);
978   TORCH_CHECK(key.dim() == 4);
979   TORCH_CHECK(value.dim() == 4);
980 
981   // Batch sizes
982   TORCH_CHECK(query.size(0) == key.size(0));
983   TORCH_CHECK(query.size(0) == value.size(0));
984 
985   // Sequence length
986   TORCH_CHECK(key.size(1) == value.size(1));
987 
988   // Num heads
989   TORCH_CHECK(query.size(2) == key.size(2));
990   TORCH_CHECK(query.size(2) == value.size(2));
991 
992   // Embedding per head
993   TORCH_CHECK(query.size(3) == key.size(3));
994 
995   int64_t max_seqlen_q = 0, max_seqlen_k = 0;
996   TORCH_CHECK(seqstart_q.has_value() == seqstart_k.has_value());
997   if (seqstart_q.has_value()) {
998     TORCH_CHECK(seqstart_q->scalar_type() == at::ScalarType::Int);
999     TORCH_CHECK(seqstart_k->scalar_type() == at::ScalarType::Int);
1000     TORCH_CHECK(seqstart_q->dim() == 1 && seqstart_k->dim() == 1);
1001     CHECK_NOSPARSE_CONTIGUOUS_CUDA((*seqstart_q));
1002     CHECK_NOSPARSE_CONTIGUOUS_CUDA((*seqstart_k));
1003     TORCH_CHECK(seqstart_q->size(0) == seqstart_k->size(0));
1004     TORCH_CHECK(query.size(0) == 1, "cu_seqlen only supports batch_size=1");
1005     TORCH_CHECK(max_seqlen_q_.has_value());
1006     max_seqlen_q = *max_seqlen_q_;
1007     max_seqlen_k = 0; // TODO: is this actually being set inside the kernel anywhere?
1008                       // see https://github.com/pytorch/pytorch/issues/115590
1009   } else {
1010     max_seqlen_q = query.size(1);
1011     max_seqlen_k = key.size(1);
1012   }
1013 
1014   CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(query);
1015   CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(key);
1016   CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(value);
1017 
1018   at::cuda::CUDAGuard device_guard(query.device());
1019   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
1020 
1021   int64_t B = query.size(0);
1022   int64_t M = query.size(1);
1023   int64_t N = key.size(1);
1024   int64_t num_heads = query.size(-2);
1025   int64_t K = query.size(-1);
1026   int64_t Kv = value.size(-1);
1027 
1028   at::Tensor res;
1029   at::Tensor logsumexp;
1030   at::Tensor seed_t, offset_t;
1031 
1032   const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO;
1033 
1034   // Note [Seed and Offset Device]
1035   // If we are currently in graph capture mode, we need to create the seed and offset tensors on the device.
1036   // This is necessary for CUDA graph-safe random number generation, which requires the seed and offset tensors
1037   // to be single element tensors on device. During graph capture, when the seed and offset tensors are passed
1038   // the pointers act as scratch space for storing the RNG state for the backwards pass.
1039   // When calling backwards, we either construct a PhiloxState with the pointers or the actual values.
1040   // For more information on CUDA graph-safe RNG states, see Note [CUDA Graph-safe RNG states].
1041 
1042   at::PhiloxCudaState philox_state;
1043   const bool in_capture_stream =
1044       at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None;
1045   auto device = in_capture_stream ? at::kCUDA : at::kCPU;
1046   if (use_dropout) {
1047     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
1048         c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
1049 
1050     // See Note [Acquire lock when using random generators]
1051     std::lock_guard<std::mutex> lock(gen->mutex_);
1052     // if using dropout, we produce 1 random number for each element of the
1053     // attention tensor
1054     philox_state = gen->philox_cuda_state(B * num_heads * M * N);
1055 
1056     if (in_capture_stream) {
1057       // The seed and offset will be populated by the kernel
1058       seed_t = at::empty({}, at::dtype(at::kLong).device(device));
1059       offset_t = at::empty({}, at::dtype(at::kLong).device(device));
1060     } else {
1061       auto [seed, offset] = at::cuda::philox::unpack(philox_state);
1062       seed_t = at::scalar_tensor(
1063           at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
1064       offset_t = at::scalar_tensor(
1065           at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
1066     }
1067   } else {
1068     // Not using dropout
1069     seed_t = at::empty({}, at::dtype(at::kLong).device(device));
1070     offset_t = at::empty({}, at::dtype(at::kLong).device(device));
1071   }
1072 
1073 #ifdef USE_ROCM
1074   // ROCM Implementation
1075   auto ret = aotriton::v2::flash::check_gpu(stream);
1076   if (hipSuccess != ret) {
1077       TORCH_CHECK(false,
1078                   "[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx942:sramecc+:xnack-)")
1079   }
1080 
1081   // AOTriton may accept an aligned logsumexp tensor in the future for better
1082   // performance, but for now it requires a compact logsumexp tensor, even if
1083   // compute_logsumexp is false
1084   constexpr int kAlignLSE = 1;
1085   res = at::empty({B, M, num_heads, Kv}, query.options());
1086   logsumexp = at::empty(
1087       { B, num_heads, max_seqlen_q },
1088       query.options().dtype(at::ScalarType::Float));
1089   at::Tensor softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q});
1090   at::Tensor q_t = query.transpose(1, 2);
1091   at::Tensor k_t = key.transpose(1, 2);
1092   at::Tensor v_t = value.transpose(1, 2);
1093   at::Tensor output_t = res.transpose(1, 2);
1094   bool is_causal;
1095   if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
1096     is_causal = true;
1097   } else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
1098     is_causal = false;
1099   } else {
1100     TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
1101   }
1102 
1103   const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
1104 
1105   using aotriton::v2::flash::attn_fwd;
1106   using sdp::aotriton_adapter::mk_aotensor;
1107   aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16);
1108   at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options());
1109   hipError_t err; // TODO: Error handling
1110   err = attn_fwd(mk_aotensor(q_t, "q"),
1111                  mk_aotensor(k_t, "k"),
1112                  mk_aotensor(v_t, "v"),
1113                  bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
1114                  softmax_scale,
1115                  mk_aotensor<2>(softmax_lse, "M"),
1116                  mk_aotensor(output_t, "Out"),
1117                  dropout_p,
1118                  use_dropout ? *seed_t.data_ptr<int64_t>() : 0,
1119                  use_dropout ? *offset_t.data_ptr<int64_t>() : 0,
1120                  mk_aotensor(softmax_fa_t, "encoded_softmax"),
1121                  is_causal,
1122                  stream);
1123   if (!compute_logsumexp) {
1124     // Set the tensor to empty when compute_logsumexp is false
1125     logsumexp = at::empty(
1126         { B * num_heads, max_seqlen_q, 0 },
1127         query.options().dtype(at::ScalarType::Float));
1128   }
1129 #else
1130   // CUDA Implementation
1131   cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());
1132   const int computeCapability = p->major * 10 + p->minor;
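  // e.g. an SM 8.0 device yields 80; dispatch_cutlassF below uses this to select
  // among the kernel instantiations compiled per architecture.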
1133 
1134   bool kernel_launched = false;
1135   const auto maxShmem = p->sharedMemPerBlockOptin;
1136 
1137   auto launchKernel = [&](auto _k, auto kernel_fn) {
1138     using Kernel = decltype(_k);
1139     using scalar_t = typename Kernel::scalar_t;
1140     (void)_k;
1141 
1142     if (kernel_launched) {
1143       return;
1144     }
1145     // Check if this kernel is compatible
1146     if (!Kernel::kSupportsDropout && use_dropout) {
1147       return;
1148     }
1149     if (!Kernel::kSupportsBias && bias.has_value()) {
1150       return;
1151     }
1152 
1153     if (value.size(3) > Kernel::kMaxK || key.size(3) > Kernel::kMaxK) {
1154       return;
1155     }
1156     // Alignment
1157     if ((query.stride(2) % Kernel::kAlignmentQ) ||
1158         (key.stride(2) % Kernel::kAlignmentK) ||
1159         (value.stride(2) % Kernel::kAlignmentV)) {
1160       return;
1161     }
1162     // Uses too much shmem
1163     size_t smem_bytes = sizeof(typename Kernel::SharedStorage);
1164     if (smem_bytes > maxShmem) {
1165       return;
1166     }
1167     kernel_launched = true;
1168 
1169     res = at::empty(
1170         {B, M, num_heads, Kv},
1171         query.options().dtype(
1172             CutlassToAtenDtype<typename Kernel::output_t>::atScalarType()));
1173 
1174     // NOTE: Should be aligned (by padding) in case M is
1175     // not a good number for loading during backward
1176     constexpr decltype(M) kAlignLSE = Kernel::kAlignLSE;
1177     logsumexp = at::empty(
1178         {seqstart_q.has_value() ? seqstart_q->size(0) - 1 : B,
1179          num_heads,
1180          compute_logsumexp ? ceil_div(max_seqlen_q, kAlignLSE) * kAlignLSE : 0},
1181         query.options().dtype(at::ScalarType::Float));
1182     typename Kernel::Params p;
1183     p.query_ptr = (const scalar_t*)query.const_data_ptr();
1184     p.key_ptr = (const scalar_t*)key.const_data_ptr();
1185     p.value_ptr = (const scalar_t*)value.const_data_ptr();
1186     p.logsumexp_ptr = compute_logsumexp
1187         ? (typename Kernel::lse_scalar_t*)logsumexp.data_ptr()
1188         : nullptr;
1189     at::Tensor output_accum;
1190     if (Kernel::kNeedsOutputAccumulatorBuffer) {
1191       output_accum = at::empty(
1192           {B, M, num_heads, Kv},
1193           query.options().dtype(
1194               CutlassToAtenDtype<
1195                   typename Kernel::output_accum_t>::atScalarType()));
1196       p.output_accum_ptr =
1197           (typename Kernel::output_accum_t*)output_accum.data_ptr();
1198     } else {
1199       p.output_accum_ptr = nullptr;
1200     }
1201     p.output_ptr = (typename Kernel::output_t*)res.data_ptr();
1202 
1203     if (seqstart_q.has_value()) {
1204       p.seqstart_q_ptr = (const int32_t*)seqstart_q->const_data_ptr();
1205       p.seqstart_k_ptr = (const int32_t*)seqstart_k->const_data_ptr();
1206     }
1207 
1208     p.num_heads = num_heads;
1209     p.head_dim = query.size(3);
1210     p.head_dim_value = value.size(3);
1211     p.num_queries = max_seqlen_q;
1212     p.num_keys = max_seqlen_k;
1213     p.num_batches = seqstart_q.has_value() ? seqstart_q->size(0) - 1 : B;
1214     p.custom_mask_type = custom_mask_type;
1215 
1216     p.seqlen_k_ptr = nullptr;
1217     if (seqlen_k.has_value()) {
1218       CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(seqlen_k.value());
1219       TORCH_CHECK(seqlen_k->scalar_type() == at::ScalarType::Int);
1220       p.seqlen_k_ptr = (const int32_t*)seqlen_k->const_data_ptr();
1221     }
1222     if (window_size.has_value()) {
1223       p.window_size = *window_size;
1224     }
1225     p.scale = sdp::calculate_scale(query, scale).as_float_unchecked();
1226 
1227     ASSIGN_CHECK_OVERFLOW(p.q_strideB, query.stride(0));
1228     ASSIGN_CHECK_OVERFLOW(p.k_strideB, key.stride(0));
1229     ASSIGN_CHECK_OVERFLOW(p.v_strideB, value.stride(0));
1230     ASSIGN_CHECK_OVERFLOW(p.q_strideM, query.stride(1));
1231     ASSIGN_CHECK_OVERFLOW(p.k_strideM, key.stride(1));
1232     ASSIGN_CHECK_OVERFLOW(p.v_strideM, value.stride(1));
1233     ASSIGN_CHECK_OVERFLOW(p.q_strideH, query.stride(2));
1234     ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2));
1235     ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2));
1236     ASSIGN_CHECK_OVERFLOW(p.o_strideM, res.stride(1));
1237 
1238     if (bias.has_value()) {
1239       CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA((*bias));
1240       TORCH_CHECK(
1241           bias->scalar_type() == CutlassToAtenDtype<scalar_t>::atScalarType(),
1242           "invalid dtype for bias - should match query's dtype");
1243       p.attn_bias_ptr = (const scalar_t*)bias->const_data_ptr();
1244 
1245       TORCH_CHECK(bias->dim() == 4, "Bias expected in BMHK format");
1246       TORCH_CHECK(
1247           bias->size(0) == query.size(0),
1248           "attn_bias: wrong shape (batch dimension)");
1249       TORCH_CHECK(
1250           bias->size(1) == query.size(2),
1251           "attn_bias: wrong shape (head dimension)");
1252       TORCH_CHECK(
1253           bias->size(2) == query.size(1),
1254           "attn_bias: wrong shape (seqlenQ dimension)");
1255       TORCH_CHECK(
1256           bias->size(3) == key.size(1),
1257           "attn_bias: wrong shape (seqlenKV dimension)");
1258       ASSIGN_CHECK_OVERFLOW(p.bias_strideB, bias->stride(0));
1259       ASSIGN_CHECK_OVERFLOW(p.bias_strideH, bias->stride(1));
1260       ASSIGN_CHECK_OVERFLOW(p.bias_strideM, bias->stride(2));
1261       TORCH_CHECK(
1262           bias->stride(3) == 1,
1263           "attn_bias: wrong alignment (last dimension must be contiguous)");
1264     }
1265 
1266     p.use_dropout = use_dropout;
1267     if (p.use_dropout) {
1268       p.rng_engine_inputs = philox_state;
1269       p.dropout_prob = dropout_p;
1270       p.seed = seed_t.data_ptr<int64_t>();
1271       p.extragraph_offset = offset_t.data_ptr<int64_t>();
1272     }
1273 
1274     if (smem_bytes > 0xc000) {
1275       auto err = cudaFuncSetAttribute(
1276           kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
1277       TORCH_CHECK(
1278           err != cudaErrorInvalidValue,
1279           "This GPU does not have enough shared-memory (kernel requires ",
1280           smem_bytes / 1024,
1281           " kb)");
1282       AT_CUDA_CHECK(err);
1283     }
1284     auto blocks = p.getBlocksGrid();
1285     if (blocks.x * blocks.y * blocks.z == 0 || key.size(1) == 0) {
1286       res.zero_();
1287       return;
1288     }
1289     Kernel::check_supported(p);
1290     kernel_fn<<<blocks, p.getThreadsGrid(), smem_bytes, stream>>>(p);
1291   };
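  // dispatch_cutlassF enumerates the cutlass forward kernels available for this
  // scalar type and compute capability and invokes launchKernel on each
  // candidate; the kernel_launched flag above makes the first compatible
  // candidate win.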
1292 
1293   // Dispatch to the right kernel
1294   DISPATCH_TYPES(query, ([&]() {
1295                    dispatch_cutlassF<scalar_t>(launchKernel, computeCapability);
1296                  }));
1297   TORCH_CHECK(kernel_launched, "cutlassF: no kernel found to launch!");
1298   AT_CUDA_CHECK(cudaGetLastError());
1299 
1300 #endif // USE_ROCM
1301   return std::make_tuple(
1302       std::move(res),
1303       std::move(logsumexp),
1304       std::move(seed_t),
1305       std::move(offset_t),
1306       max_seqlen_q,
1307       // TODO: why isn't this being set in the kernel?
1308       max_seqlen_k_.has_value() ? max_seqlen_k_.value() : max_seqlen_k);
1309 #endif
1310   TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.")
1311   return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}, 0, 0);
1312 }
1313 
1314 Tensor triton_scaled_dot_attention(const Tensor& q, const Tensor& k, const Tensor& v, double dropout_p){
1315   TORCH_CHECK(false, "This operator should be overridden in python before use");
1316   return at::Tensor();
1317 }
1318 
1319 REGISTER_CUDA_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cuda);
1320 
1321 #if defined(USE_MEM_EFF_ATTENTION) and !defined(USE_ROCM)
1322 namespace {
1323 /**
1324  * simple kernel that populates a tensor with rand uniform values.
1325  * currently only used for testing purposes, not much attention
1326  * is paid to performance.
1327  *
1328  * problem is partitioned as follows:
1329  * - (batch, head) is given by block coordinates
1330  * - each thread handles a row for a given (batch, head)
1331  */
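// Launch shape used by the caller below: grid = (batch_sz, n_heads), block =
// n_queries threads; each thread fills its row of n_keys values four at a time
// via curand_uniform4.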
1332 template <typename mask_t>
1333 __global__ void rand_uniform_kernel(
1334     int64_t n_heads,
1335     int64_t n_queries,
1336     int64_t n_keys,
1337     float dropout_prob,
1338     at::PhiloxCudaState rng_engine_inputs,
1339     mask_t* mask_out,
1340     int64_t mask_numel) {
1341   const int64_t batch_id = blockIdx.x;
1342   const int64_t head_id = blockIdx.y;
1343   const int64_t query_idx = threadIdx.x;
1344 
1345   const auto seeds = at::cuda::philox::unpack(rng_engine_inputs);
1346 
1347   const int dropout_seq_start = batch_id * (n_heads * n_queries * n_keys) +
1348       head_id * (n_queries * n_keys);
1349   const int64_t query_start_idx = query_idx * n_keys;
1350 
1351   curandStatePhilox4_32_10_t curand_state;
1352   curand_init(
1353       std::get<0>(seeds),
1354       0,
1355       std::get<1>(seeds) + dropout_seq_start + query_start_idx,
1356       &curand_state);
1357 
1358   for (int key_start_idx = 0; key_start_idx < n_keys; key_start_idx += 4) {
1359     float4 rand_quad = curand_uniform4(&curand_state);
1360 
1361 #pragma unroll
1362     for (int i = 0; i < 4; ++i) {
1363       const int64_t linear_idx = dropout_seq_start + query_start_idx + key_start_idx + i;
1364       if (linear_idx < mask_numel) {
1365         mask_out[linear_idx] = (&rand_quad.x)[i];
1366       }
1367     }
1368   }
1369 }
1370 } // namespace
1371 #endif // defined(USE_MEM_EFF_ATTENTION) and !defined(USE_ROCM)
1372 /**
1373  * fill tensor with random uniform values. only used for testing, not much
1374  * attention is paid to performance
1375  */
1376 at::Tensor& _fill_mem_eff_dropout_mask_(
1377     Tensor& self,
1378     double dropout_p,
1379     const int64_t seed,
1380     const int64_t offset) {
1381   TORCH_CHECK(self.is_contiguous());
1382   TORCH_CHECK(self.dtype() == at::ScalarType::Float);
1383   const int64_t batch_sz = self.size(0);
1384   const int64_t n_heads = self.size(1);
1385   const int64_t n_queries = self.size(2);
1386   const int64_t n_keys = self.size(3);
1387 #if defined(USE_MEM_EFF_ATTENTION)
1388 
1389 #ifdef USE_ROCM
1390   using aotriton::v2::flash::debug_fill_dropout_rng;
1391   using sdp::aotriton_adapter::mk_aotensor;
1392   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
1393   hipError_t err; // TODO: Error handling
1394 
1395   err = debug_fill_dropout_rng(mk_aotensor(self, "r"),
1396                                static_cast<uint64_t>(seed),
1397                                static_cast<uint64_t>(offset),
1398                                stream);
1399 #else
1400   at::PhiloxCudaState rng_engine_inputs;
1401   rng_engine_inputs = at::PhiloxCudaState(seed, offset);
1402   at::cuda::CUDAGuard device_guard(self.device());
1403   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
1404 
1405   rand_uniform_kernel<float><<<dim3(batch_sz, n_heads), n_queries, 0, stream>>>(
1406       n_heads,
1407       n_queries,
1408       n_keys,
1409       dropout_p,
1410       rng_engine_inputs,
1411       reinterpret_cast<float*>(self.data_ptr()),
1412       self.numel());
1413 #endif
1414 
1415   return self;
1416 #endif
1417   TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.")
1418   return self;
1419 }
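// Usage sketch (the `mask` name is hypothetical): given a contiguous float CUDA
// tensor of shape [batch, n_heads, n_queries, n_keys],
//   at::native::_fill_mem_eff_dropout_mask_(mask, dropout_p, seed, offset);
// fills it in-place with uniform randoms from the given philox seed/offset,
// which the tests use when checking memory-efficient attention dropout.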
1420 
1421 } // namespace native
1422 } // namespace at
1423