#include <ATen/core/TensorBody.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/TensorOperators.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <c10/util/typeid.h>
#include <c10/core/DeviceType.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/Logging.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>

#include <type_traits>
#include <limits>
#include <utility>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_sdp_choice_native.h>
#include <ATen/ops/_masked_softmax.h>
#include <ATen/ops/_native_multi_head_attention_native.h>
#include <ATen/ops/_nested_from_padded.h>
#include <ATen/ops/_nested_tensor_softmax_with_shape.h>
#include <ATen/ops/_scaled_dot_product_attention_math.h>
#include <ATen/ops/_scaled_dot_product_attention_math_native.h>
#include <ATen/ops/_scaled_dot_product_efficient_attention.h>
#include <ATen/ops/_scaled_dot_product_flash_attention.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_backward_native.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_native.h>
#include <ATen/ops/_scaled_dot_product_cudnn_attention.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_native.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_native.h>
#include <ATen/ops/_softmax.h>
#include <ATen/ops/_transform_bias_rescale_qkv.h>
#include <ATen/ops/_transform_bias_rescale_qkv_native.h>
#include <ATen/ops/_triton_multi_head_attention_native.h>
#include <ATen/ops/_triton_scaled_dot_attention.h>
#include <ATen/ops/bmm.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/chunk_native.h>
#include <ATen/ops/dropout.h>
#include <ATen/ops/linear_native.h>
#include <ATen/ops/matmul.h>
#include <ATen/ops/matmul_native.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/pad.h>
#include <ATen/ops/scaled_dot_product_attention_native.h>
#include <ATen/ops/softmax.h>
#include <ATen/ops/split_native.h>
#include <ATen/ops/split_with_sizes_native.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif

#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
namespace at {

namespace native {

DEFINE_DISPATCH(_fused_sdp_choice_stub);

DEFINE_DISPATCH(transform_bias_rescale_qkv_stub);
DEFINE_DISPATCH(flash_attention_kernel);
DEFINE_DISPATCH(flash_attention_backward_kernel);

namespace {

Tensor gemm_nt(const Tensor& self, const Tensor& other) {
  if (self.is_nested()) {
    return NestedTensor_matmul(self, other.t());
  } else {
    return at::native::matmul(self, other.t());
  }
}
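// gemm_nt computes self @ other.t() ("nt" = normal x transposed), routing to
// the NestedTensor matmul when `self` is nested and to the dense matmul
// otherwise. For a [B, T, D] input and a [3 * D, D] projection weight this
// yields the packed [B, T, 3 * D] QKV projection used below.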

Tensor transform_0213(const Tensor& a) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(1));
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(3));
  return a.permute({0, 2, 1, 3})
      .contiguous()
      .view({a.size(0), a.size(2), a.size(1) * a.size(3)});
}
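// transform_0213 reshapes [B, num_head, T, dim_per_head] back into the
// [B, T, D] layout (D = num_head * dim_per_head) by swapping the head and
// sequence dimensions and flattening the per-head slices, e.g.
// [2, 8, 16, 64] -> [2, 16, 512].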

} // namespace


Tensor bmm_nt(const Tensor& a, const Tensor& b) {
  auto a_ = a.view({a.size(0) * a.size(1), a.size(2), a.size(3)});
  auto b_ = b.view({b.size(0) * b.size(1), b.size(2), b.size(3)});
  auto bt_ = b_.transpose(2, 1);
  auto c_ = at::bmm(a_, bt_);
  return c_.view({a.size(0), a.size(1), a.size(2), b.size(2)});
}
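// bmm_nt multiplies two [B, num_head, T, dim_per_head] tensors as a @ b^T per
// (batch, head) pair by folding B and num_head into one batch dimension for
// at::bmm, producing [B, num_head, T, T] attention scores.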

Tensor masked_softmax(
    Tensor& attn_scores,
    std::optional<Tensor> attn_mask,
    const Tensor& query,
    std::optional<int64_t> mask_type) {
  if (query.is_nested() && !attn_mask) {
    return at::_nested_tensor_softmax_with_shape(attn_scores, query);
  }
  if (attn_mask && attn_mask->dtype() != at::kBool) {
    attn_mask = attn_mask->to(at::kBool);
  }
  if (attn_mask) {
    return _masked_softmax(attn_scores, *attn_mask, attn_scores.dim() - 1, mask_type);
  } else {
    return _softmax_out(attn_scores, attn_scores, attn_scores.dim() - 1, false);
  }
}
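// masked_softmax normalizes the attention scores along the last dimension.
// Nested inputs without a mask take the shape-aware nested softmax path; a
// non-bool mask is first converted to bool so it can be consumed by
// _masked_softmax; and the unmasked dense case runs _softmax_out in place.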
125 
bmm_nn(Tensor & out,const Tensor & a,const Tensor & b)126 Tensor bmm_nn(Tensor& out, const Tensor& a, const Tensor& b) {
127   const std::array<int64_t, 3> newAShape = {
128       a.sizes()[0] * a.sizes()[1], a.sizes()[2], a.sizes()[3]};
129   auto a_ = a.view(newAShape);
130   const std::array<int64_t, 3> newBShape = {
131       b.sizes()[0] * b.sizes()[1], b.sizes()[2], b.sizes()[3]};
132   auto b_ = b.view(newBShape);
133   auto out_ = out.reshape({newAShape[0], newAShape[1], newBShape[2]});
134   auto c_ = at::bmm_out(out_, a_, b_);
135   return c_.view({a.size(0), a.size(1), a.size(2), b.size(3)});
136 }
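// bmm_nn multiplies attention weights [B, num_head, T, T] with values
// [B, num_head, T, dim_per_head], writing the result into `out` via
// at::bmm_out so the buffer that backed q can be reused instead of
// allocating a fresh [B, num_head, T, dim_per_head] tensor.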


Tensor transform0213_gemm_nt_bias(
    const Tensor& a,
    const Tensor& b,
    const Tensor& c,
    const Tensor& query) {
  if (query.is_nested()) {
    at::Tensor nested_a = _nested_from_padded(
        a, get_nested_tensor_impl(query)->get_nested_sizes(), true);
    return NestedTensor_times_Tensor_plus_Tensor_addmm(
        c, nested_a, b.t(), 1, 1);
  } else {
    const Tensor a_0213 = transform_0213(a);
    auto a_ = a_0213.view({a_0213.size(0) * a_0213.size(1), a_0213.size(2)});
    auto r_ = at::native::linear(a_, b, c);
    return r_.view({a_0213.size(0), a_0213.size(1), r_.size(1)});
  }
}
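// transform0213_gemm_nt_bias fuses the [B, num_head, T, dim_per_head] ->
// [B, T, D] layout change with the output projection a @ b^T + c. For nested
// queries the padded result is converted back to a NestedTensor before the
// addmm; for dense queries it is a plain linear over the flattened rows.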

void debug_assert_shape(int line, const Tensor& t, c10::IntArrayRef shape) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      (size_t)t.dim() == shape.size(),
      "(called from line ",
      line,
      ") ",
      "expected ",
      shape.size(),
      "-D tensor but got ",
      t.dim());
  if (t.is_nested()) {
    return;
  }
  for (auto idx : c10::irange(shape.size())) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        shape[idx] == 0 || t.sizes()[idx] == shape[idx],
        "(called from line ",
        line,
        ") ",
        "expected dim ",
        idx,
        " to be ",
        shape[idx],
        " but got ",
        t.sizes()[idx]);
  }
}

Tensor qkv_projection(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const int64_t embed_dim,
    const Tensor& qkv_weight) {
  // shape: [B, T, 3 x D]
  Tensor qkv;

  if (key.is_same(value)) {
    if (query.is_same(key)) {
      // self-attention
      qkv = gemm_nt(query, qkv_weight);
    } else {
      // encoder-decoder attention
      // TODO: is there a more efficient way to set this up?
      // TODO: can we stay nested instead of using cat? Probably just make a
      // NestedTensor out of the matmul results or something?
      auto q_kv_weight_s =
          at::native::split_with_sizes(qkv_weight, {embed_dim, embed_dim * 2}, 0);
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          q_kv_weight_s.size() == 2,
          "expected split to produce 2 tensors but it produced ",
          q_kv_weight_s.size());
      auto q = gemm_nt(query, q_kv_weight_s[0]);
      auto kv = gemm_nt(key, q_kv_weight_s[1]);
      qkv = at::cat({std::move(q), std::move(kv)}, 2);
    }
  } else {
    auto q_k_v_weight_s = at::native::chunk(qkv_weight, 3, 0);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        q_k_v_weight_s.size() == 3,
        "expected chunk to produce 3 tensors but it produced ",
        q_k_v_weight_s.size());
    // TODO: can we stay nested instead of using cat?
    auto q = gemm_nt(query, q_k_v_weight_s[0]);
    auto k = gemm_nt(key, q_k_v_weight_s[1]);
    auto v = gemm_nt(value, q_k_v_weight_s[2]);
    qkv = at::cat({std::move(q), std::move(k), std::move(v)}, 2);
  }

  return qkv;
}
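// qkv_projection distinguishes three aliasing patterns of the inputs:
//   * query == key == value: self-attention, one fused [B, T, 3 * D] matmul;
//   * key == value only: encoder-decoder attention, the weight is split into
//     a D-row block for q and a 2 * D-row block for k/v;
//   * all distinct: the weight is chunked into three D-row blocks.
// In the last two cases the per-input projections are concatenated along the
// feature dimension so downstream code always sees a packed [B, T, 3 * D] qkv.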

// compute q = (q + q_bias) / sqrt(dim_per_head), k = k + k_bias, v = v + v_bias
std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_cpu(
    const Tensor& qkv,
    const Tensor& qkv_bias,
    const int64_t num_head) {
  auto qkv_ = qkv.is_nested()
    ? c10::MaybeOwned<Tensor>::owned(qkv.to_padded_tensor(0))
    : c10::MaybeOwned<Tensor>::borrowed(qkv);
  auto B = qkv_->size(0);
  auto T = qkv_->size(1);
  auto _3D = qkv_->size(2);
  auto D = _3D / 3;
  TORCH_CHECK(D % num_head == 0);
  TORCH_CHECK(_3D % 3 == 0);
  const auto dim_per_head = D / num_head;
  auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv_->options());

  const auto qkv_contig = qkv_->expect_contiguous();
  const auto qkv_bias_contig = qkv_bias.expect_contiguous();
  transform_bias_rescale_qkv_stub(
      kCPU,
      qkv_->scalar_type(),
      q_k_v.data_ptr(),
      qkv_contig->const_data_ptr(),
      qkv_bias_contig->const_data_ptr(),
      B, T, D, num_head);
  auto q_k_v_s =
      at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(q_k_v_s.size() == 3);
  return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
}
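// The CPU kernel writes q, k, and v into one [3, B, num_head, T, dim_per_head]
// buffer (padding nested inputs with zeros first); the buffer is then viewed
// as [3 * B, ...] and split into three [B, num_head, T, dim_per_head] tensors
// that share that storage.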

std::tuple<Tensor, Tensor> native_multi_head_attention_cpu(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const int64_t embed_dim,
    const int64_t num_head,
    const Tensor& qkv_weight,
    const Tensor& qkv_bias,
    const Tensor& proj_weight,
    const Tensor& proj_bias,
    const std::optional<Tensor>& mask,
    bool need_weights,
    bool average_attn_weights,
    const std::optional<int64_t> mask_type) {
  // query shape: [B, T, D]
  // qkv_weight shape: [3 * D, D]

  TORCH_CHECK(
      !mask || !query.is_nested(),
      "NestedTensor with mask is not supported yet");
  const auto D = embed_dim;
  TORCH_CHECK(
      query.dim() == 3,
      "expected 3-D `query`, got ",
      query.dim(),
      "-D tensor");
  TORCH_CHECK(
      query.is_nested() || query.sizes()[2] == embed_dim,
      "passed-in embed_dim ",
      embed_dim,
      " didn't match last dim of query ",
      query.sizes()[2]);
  TORCH_CHECK(
      key.dim() == 3,
      "expected 3-D `key`, got ",
      key.dim(),
      "-D tensor");
  TORCH_CHECK(
      value.dim() == 3,
      "expected 3-D `value`, got ",
      value.dim(),
      "-D tensor");
  TORCH_CHECK(
      query.is_nested() || key.is_nested() || value.is_nested() ||
          (query.sizes() == key.sizes() && key.sizes() == value.sizes()),
      "expected `query`/`key`/`value` shapes to match");
  TORCH_CHECK(
      qkv_weight.dim() == 2,
      "expected 2-D `qkv_weight`, got ",
      qkv_weight.dim(),
      "-D tensor");
  TORCH_CHECK(
      D * 3 == qkv_weight.sizes()[0],
      "expected `qkv_weight` first dim to be 3x embed_dim");
  TORCH_CHECK(
      D == qkv_weight.sizes()[1],
      "expected `qkv_weight` second dim to be embed_dim");
  TORCH_CHECK(
      qkv_bias.dim() == 1,
      "expected 1-D `qkv_bias`, got ",
      qkv_bias.dim(),
      "-D tensor");
  TORCH_CHECK(
      qkv_bias.sizes()[0] == 3 * D,
      "expected `qkv_bias` first dim to be 3x embed_dim");
  TORCH_CHECK(D % num_head == 0, "`embed_dim` must be divisible by `num_heads`");

#ifndef NDEBUG
  const auto B = query.is_nested()
      ? get_nested_tensor_impl(query)->get_nested_sizes().size(0)
      : query.sizes()[0];
  auto T = query.is_nested() ? 0 : query.sizes()[1];
  const auto dim_per_head = D / num_head;
#endif

  // shape: [B, T, 3 x D]
  auto qkv = qkv_projection(query, key, value, embed_dim, qkv_weight);

  if (!qkv.is_nested() && qkv.numel() == 0) {
    if (query.is_nested()) {
      return std::make_tuple(Tensor(), Tensor());
    }
    return std::make_tuple(at::empty_like(query), Tensor());
  }

#ifndef NDEBUG
  if (!query.is_nested() || !qkv.is_nested()) {
    if (query.is_nested()) {
      T = qkv.size(1);
    }
    debug_assert_shape(__LINE__, qkv, {B, T, 3 * D});
  }
#endif

#ifdef DEBUG_PRINT_EACH_STEP
  if (!qkv.is_nested()) {
    std::cerr << "qkv: " << qkv << std::endl;
  }
#endif
  // shape: 3 x [B, num_head, T, dim_per_head]
  auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
  qkv = Tensor(); // Not used any more, allow free
  auto& q = std::get<0>(q_k_v);
  const auto& k = std::get<1>(q_k_v);
  const auto& v = std::get<2>(q_k_v);
#ifndef NDEBUG
  debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head});
  debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head});
  debug_assert_shape(__LINE__, v, {B, num_head, T, dim_per_head});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "q: " << q << std::endl;
  std::cerr << "k: " << k << std::endl;
  std::cerr << "v: " << v << std::endl;
#endif

  // shape: [B, num_head, T, T]
  auto qkt = bmm_nt(q, k);
  // q & k are dead but cannot be freed because they were packed with v
#ifndef NDEBUG
  debug_assert_shape(__LINE__, qkt, {B, num_head, T, T});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "qkt: " << qkt << std::endl;
#endif

  // shape: [B, num_head, T, T]
  // TODO: long-term, have a kernel that works with
  // NestedTensor directly if there is no mask passed
  qkt = masked_softmax(qkt, mask, query, mask_type);
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "qkt after softmax: " << qkt << std::endl;
#endif

  // shape: [B, num_head, T, dim_per_head]
  // reuse storage for q; we're done with it
  auto attn_ctx = bmm_nn(q, qkt, v);
  // qkv is not dead; we just reused storage for q!
  if (!need_weights) {
    qkt = Tensor();
  }
#ifndef NDEBUG
  debug_assert_shape(__LINE__, attn_ctx, {B, num_head, T, dim_per_head});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "attn_ctx: " << attn_ctx << std::endl;
#endif

  // shape: [B, T, D]
  // Fuse transform_0213 inside
  auto proj = transform0213_gemm_nt_bias(
      attn_ctx, proj_weight, proj_bias, query);
#ifndef NDEBUG
  debug_assert_shape(__LINE__, proj, {B, T, D});
#endif
  if (need_weights && average_attn_weights) {
    // weights are not needed for the full transformer, so don't worry too
    // much about performance -- we implement this just so that use cases
    // which don't disable need_weights still get some speedup.
    qkt = qkt.sum(1);
    qkt /= num_head;
  }
  return std::make_tuple(std::move(proj), std::move(qkt));
}

int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value,
        const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale){
  sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal};
  auto backend = sdp::select_sdp_backend_cpp(kernel_params);
  if (backend == sdp::SDPBackend::error) {
    TORCH_CHECK(
        false,
        "No viable backend for scaled_dot_product_attention was found. ",
        "This is likely due to turning off both the math kernel and the fused kernels.");
  }
  return static_cast<int64_t>(backend);
}

REGISTER_ARCH_DISPATCH(_fused_sdp_choice_stub, DEFAULT, &_fused_sdp_choice_cpp);
REGISTER_AVX2_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);

int64_t _fused_sdp_choice_meta(
    const Tensor& query_,
    const Tensor& key,
    const Tensor& value,
    const std::optional<Tensor>& attn_mask_,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale) {
  auto query_key_set = query_.key_set();
#if defined(USE_ROCM)
  bool has_rocm = query_key_set.has(c10::DispatchKey::HIP);
  if (has_rocm) {
    auto choice_int = _fused_sdp_choice_stub(at::kHIP, query_, key, value, attn_mask_, dropout_p, is_causal, scale);
    return choice_int;
  }
#else
  bool has_cuda = query_key_set.has(c10::DispatchKey::CUDA);
  if (has_cuda) {
    auto choice_int = _fused_sdp_choice_stub(
        at::kCUDA,
        query_,
        key,
        value,
        attn_mask_,
        dropout_p,
        is_causal,
        scale);
    return choice_int;
  }
#endif
  return static_cast<int64_t>(sdp::SDPBackend::math);
}
namespace {

inline void validate_sdpa_input(
    const Tensor& query_,
    const Tensor& key,
    const Tensor& value,
    const std::optional<Tensor>& attn_mask_,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale) {
  TORCH_CHECK(
      query_.dtype() == key.dtype() && query_.dtype() == value.dtype(),
      "Expected query, key, and value to have the same dtype, but got query.dtype: ",
      query_.dtype(), " key.dtype: ", key.dtype(), " and value.dtype: ", value.dtype(), " instead.");
  TORCH_CHECK(
      query_.device() == key.device() && query_.device() == value.device(),
      "Expected query, key, and value to have the same device type, but got query.device: ",
      query_.device(), " key.device: ", key.device(), " and value.device: ", value.device(), " instead.");
  TORCH_CHECK(
      query_.dim() >= 2 && key.dim() >= 2 && value.dim() >= 2,
      "Expected query, key, and value to all be at least 2 dimensional, but got query.dim: ",
      query_.dim(), " key.dim: ", key.dim(), " and value.dim: ", value.dim(), " instead.");
  if (attn_mask_.has_value()){
    auto mask_dtype = attn_mask_->dtype();
    TORCH_CHECK(mask_dtype == at::kBool || mask_dtype == at::kFloat || mask_dtype == query_.dtype(),
      "Expected attn_mask dtype to be bool or float or to match query dtype, but got attn_mask.dtype: ",
      mask_dtype, " and query.dtype: ", query_.dtype(), " instead.");
    TORCH_CHECK(
      !query_.is_nested() && !key.is_nested(),
      "Scaled_dot_product_attention: Nested tensors for query / key are not supported "
      "when an explicit attn_mask is set");
  }
  return;
}
// This function is used to produce an attn_mask
// in a standard format that can be consumed by both
// the math and memory-efficient attention implementations
//  Args:
//    attn_mask: attn_mask of shape (B, L, S) or (L, S) or (B, N_heads, L, S)
std::optional<Tensor> convert_boolean_attn_mask(const std::optional<Tensor>& attn_mask, caffe2::TypeMeta dtype) {
  // Pass through
  if(!attn_mask.has_value()){
    return c10::nullopt;
  }
  // Convert boolean mask to additive mask; need to invert mask to indicate what
  // to mask *out*.
  if (attn_mask->dtype() == at::kBool) {
    auto new_attn_mask = at::zeros_like(attn_mask.value(), dtype);
    // TODO Use the max type of the input and output
    new_attn_mask.masked_fill_(
        attn_mask->logical_not(), -std::numeric_limits<double>::infinity());
    return new_attn_mask;
  }
  // Otherwise, attn_mask represents an additive attention tensor
  return attn_mask;
}
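// For example, a boolean mask {true, false, true} (true = attend) becomes the
// additive mask {0, -inf, 0}: positions that must be ignored receive -inf so
// they contribute nothing after the softmax, while allowed positions are left
// unchanged. Float masks are assumed to already be in this additive form and
// pass through untouched.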
// Memory-efficient attention requires a padded attn_mask bias.
// This function pads the attn_mask bias so its last dimension is a multiple of
// the required alignment, then slices the padded bias back to the original
// size. We apply this function to the top-level SDPA so that, if padding is
// done, it will be tracked for backward automatically.

template<int alignment>
bool aligned_tensor(const at::Tensor& tensor){
  for(const auto i : c10::irange(tensor.dim() - 1)){
    if(tensor.sym_stride(i) % alignment != 0){
      return false;
    }
  }
  return tensor.sym_stride(-1) == 1;
}

template <int alignment>
at::Tensor pad_bias(const at::Tensor& attn_bias) {
  auto last_dim_size = attn_bias.sym_size(-1);
  auto pad_count = alignment - (last_dim_size % alignment);
  auto padded_bias = at::pad_symint(attn_bias, {c10::SymInt(0), pad_count});
  return padded_bias.slice_symint(-1, 0, last_dim_size);
}
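// Padding example: a bias with last dimension 37 and alignment 8 is padded by
// 8 - (37 % 8) = 3 columns to 40, then sliced back to 37. The returned view
// has the original size, but its underlying storage (and hence its leading
// strides) is aligned, which is what aligned_tensor checks.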

at::Tensor preprocess_mask(
    const at::Tensor& mask,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value) {
  constexpr int mem_eff_alignment = 8;
  at::Tensor result_mask = mask;
  if (!aligned_tensor<mem_eff_alignment>(mask)) {
    result_mask = pad_bias<mem_eff_alignment>(mask);
  }
  return result_mask.expand_symint(
      {query.sym_size(0),
       query.sym_size(1),
       query.sym_size(2),
       key.sym_size(2)});
}
// FlashAttentionV2 requires the head dimension to be a multiple of 8.
// This was previously done within the kernel; however, that caused the kernel
// to potentially alias query, key, and value, so instead we pad the head
// dimension to a multiple of 8 in the composite region.
template <int alignment_size, bool slice>
at::Tensor pad_last_dim(const at::Tensor& attn_bias) {
  auto last_dim_size = attn_bias.sym_size(-1);
  if (last_dim_size % alignment_size == 0) {
    return attn_bias;
  }
  auto pad_count = alignment_size - (last_dim_size % alignment_size);
  auto padded_bias = at::pad_symint(attn_bias, {c10::SymInt(0), pad_count});
  if (slice) {
    return padded_bias.slice_symint(-1, 0, last_dim_size);
  }
  return padded_bias;
}
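// Example: with alignment_size = 8 and slice = false, a head dimension of 100
// is padded with 8 - (100 % 8) = 4 zero columns to 104; an already-aligned
// head dimension of 128 is returned unchanged. The padded output is trimmed
// back to the original head size by post_process_flash_output below.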

at::Tensor post_process_flash_output(
    at::Tensor out,
    c10::SymInt const& og_size) {
  if (!out.is_nested() && out.sym_size(-1) != og_size) {
    out = out.slice_symint(-1, 0, og_size);
  }
  return out;
}

int64_t handle_private_use(const Tensor& query_, const Tensor& key, const Tensor& value,
    const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale){
  int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
  try {
    choice_int = _fused_sdp_choice_stub(query_.device().type(),
        query_, key, value, attn_mask_, dropout_p, is_causal, scale);
  } catch(const ::c10::Error& e){
  }
  return choice_int;
}

bool should_compute_logsumexp(const Tensor& query, const Tensor& key, const Tensor& value) {
  const bool any_inputs_require_grad = query.requires_grad() || key.requires_grad() || value.requires_grad();
  const bool gradmode_enabled = at::GradMode::is_enabled();
  return any_inputs_require_grad && gradmode_enabled;
}

} // namespace

// Computes scaled dot product attention on query, key and value tensors, using
// an optional attention mask if passed, and applying dropout if a probability
// greater than 0.0 is specified.
//
// Args:
//     query (Tensor): Query tensor; shape (N, ..., L, E)
//     key (Tensor): Key tensor; shape (N, ..., S, E)
//     value (Tensor): Value tensor; shape (N, ..., S, E)
//     attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights,
//         which is (N,..., L, S). Two types of masks are supported.
//         A boolean mask where a value of True indicates that the element *should* take part in attention.
//         A float mask of the same type as query, key, value that is added to the attention score.
//     dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
//     need_attn_weights (bool): If true, the second return value will contain the attention weights used;
//         otherwise, the second return value is unspecified
//     is_causal (bool): If true, assumes causal attention masking; for this case, attn_mask should not be set.
//         TODO: Consider removing this flag before promoting this function to the public API. It's possible
//         to get specialized support for causal masks (and other types of masking e.g. local attention / block
//         sparse masks) via tensor subclassing, allowing for a leaner API.
//
// Returns a tensor:
//     output (Tensor): Attention output; shape (N, ..., L, E)
//
// Shape legend:
//     N: Batch size
//     ...: Any number of other batch dimensions (optional)
//     S: Source sequence length
//     L: Target sequence length
//     E: Embedding dimension
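//
// In math terms, the computation every backend implements is
//     output = softmax(Q @ K^T * scale + attn_mask) @ V
// where scale defaults to 1 / sqrt(E) when not supplied, attn_mask is the
// additive mask produced by convert_boolean_attn_mask (or the causal mask),
// and the softmax is taken over the S dimension.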
Tensor scaled_dot_product_attention(
    const Tensor& query_,
    const Tensor& key,
    const Tensor& value,
    const std::optional<Tensor>& attn_mask_,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale) {
  validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal, scale);
  int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
  if (_fused_sdp_choice_stub.is_device_supported(query_.device().type())) {
    choice_int = _fused_sdp_choice_stub(query_.device().type(),
          query_, key, value, attn_mask_, dropout_p, is_causal, scale);
  }
  sdp::SDPBackend backend = static_cast<sdp::SDPBackend>(choice_int);
  std::optional<Tensor> attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype());
  switch (backend) {
    case sdp::SDPBackend::cudnn_attention: {
      bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
      auto out_lse_softmax = at::_scaled_dot_product_cudnn_attention(
          query_, key, value, dropout_p, is_causal, compute_logsumexp, scale);
      return std::get<0>(out_lse_softmax);
    }
    case sdp::SDPBackend::flash_attention: {
      if(query_.device().type() == DeviceType::CUDA){
        c10::SymInt og_size = query_.sym_size(-1);
        Tensor query_padded = pad_last_dim<8, false>(query_);
        Tensor key_padded = pad_last_dim<8, false>(key);
        Tensor value_padded = pad_last_dim<8, false>(value);
        // We need to calculate the scale based off the OG head dim size
        auto og_scale = sdp::calculate_scale(query_, scale);
        auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
            query_padded, key_padded, value_padded, dropout_p, is_causal, false /*return_debug_mask*/, og_scale.as_float_unchecked());
        return post_process_flash_output(std::get<0>(out_lse_softmax), og_size);
      }
      // For the CPU case we do not need to pad the last dim
      return std::get<0>(at::_scaled_dot_product_flash_attention_for_cpu(
          query_, key, value, dropout_p, is_causal, attn_mask, scale));
    }
    case sdp::SDPBackend::efficient_attention: {
      bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
      if (attn_mask.has_value()) {
        attn_mask.value() = preprocess_mask(attn_mask.value(), query_, key, value);
      }
      auto out_and_lse = at::_scaled_dot_product_efficient_attention(
          query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal, scale);
      return std::get<0>(out_and_lse);
    }
    case sdp::SDPBackend::math:
      return std::get<0>(at::_scaled_dot_product_attention_math(
          query_,
          key,
          value,
          attn_mask,
          dropout_p,
          is_causal,
          c10::nullopt, /*dropout_mask*/
          scale));
    default:
      TORCH_CHECK(
          false,
          "No viable backend for scaled_dot_product_attention was found.");
      return Tensor();
  }
}

std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
        const Tensor& query_, const Tensor& key, const Tensor& value,
        const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal,
        const std::optional<Tensor>& dropout_mask, std::optional<double> scale) {
  C10_LOG_API_USAGE_ONCE("torch.sdpa.math_fallback");
  if (query_.is_nested() || key.is_nested() || value.is_nested()) {
    TORCH_CHECK(
        query_.is_contiguous() && key.is_contiguous() &&
            value.is_contiguous(),
        "scaled_dot_product_attention: If inputs are nested tensors they must be contiguous");
  }
  auto attn_mask = attn_mask_;
  // Naive, composite implementation defined here.

  // Scale q, k before matmul for stability; see https://tinyurl.com/sudb9s96 for math
  bool is_negative_scaling = scale.has_value() && scale.value() < 0.0;
  const auto scaling_factor = sdp::calculate_scale(query_, is_negative_scaling ? std::abs(scale.value()) : scale).sqrt();

  const auto query = query_ * (is_negative_scaling ? c10::SymFloat(0.0) - scaling_factor : scaling_factor);
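  // Splitting the scale across q and k keeps the intermediate magnitudes
  // small: with the default scale 1 / sqrt(E), both q and k are multiplied by
  // (1 / sqrt(E))^(1/2), so q @ k^T below carries the full 1 / sqrt(E) factor
  // without ever materializing an unscaled score matrix. A negative `scale`
  // is split as sqrt(|scale|) on each side, with the sign carried by q.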
  if (is_causal) {
    TORCH_CHECK(!attn_mask.has_value(),
        "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");
    TORCH_CHECK(!query.is_nested() && !key.is_nested(),
        "_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True");

    // Replace attn_mask with causal mask; lower triangular elements take part in attention.
    const auto L = query.sym_size(-2), S = key.sym_size(-2);
    attn_mask = at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
    attn_mask = convert_boolean_attn_mask(attn_mask, query.dtype());
  }
  auto attn = at::matmul(query, key.transpose(-2, -1) * scaling_factor);
  if (attn_mask.has_value()) {
    if (at::areAnyTensorSubclassLike({attn, *attn_mask})) {
      attn = attn.add(*attn_mask);
    } else {
      attn.add_(*attn_mask);
    }
  }
  attn = at::softmax(attn, -1);
  if (dropout_p > 0.0) {
    if (dropout_mask.has_value()) {
      // To validate the correctness of the fused kernels, we need to use the
      // same dropout mask so that the results can be compared.
      TORCH_WARN_ONCE("Dropout mask should only be used for testing purposes.");
      attn = attn.masked_fill(dropout_mask->logical_not(), 0.0);
      auto dropout_scaling = 1.0 / (1 - dropout_p);
      return std::make_tuple(at::matmul(attn, value * dropout_scaling), attn);
    } else {
      attn = at::dropout(attn, dropout_p, true);
    }
  }

  return std::make_tuple(at::matmul(attn, value), attn);
}

std::tuple<at::Tensor, at::Tensor>
_scaled_dot_product_flash_attention_cpu(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    double dropout_p,
    bool is_causal,
    const std::optional<Tensor>& attn_mask,
    std::optional<double> scale) {
  const auto dtype = query.scalar_type();
  int64_t batchSize = query.size(0);
  int64_t qSize = query.size(2);
  int64_t num_head = query.size(1);
  int64_t headSize = query.size(3);

  TORCH_CHECK(c10::isFloatingType(dtype),
    "scaled_dot_product_attention_flash_attention: Expected data type in FP32, FP64, BF16, FP16, but got ", dtype, " instead.");
  TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
    "scaled_dot_product_attention_flash_attention: Accepts only 4-dim inputs with shape {B, H, T, K}");
  TORCH_CHECK(dropout_p == 0.0,
    "scaled_dot_product_attention_flash_attention: Currently do not support dropout > 0");
  TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
    "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");
  TORCH_CHECK(!attn_mask.has_value() ||
          attn_mask.value().scalar_type() == at::kFloat ||
          dtype == attn_mask.value().scalar_type(),
    "scaled_dot_product_attention_flash_attention: Attention mask must be FP32 or the same data type as query");
  TORCH_CHECK(!attn_mask.has_value() ||
          (attn_mask.value().dim() == 2 || attn_mask.value().dim() == 4),
    "scaled_dot_product_attention_flash_attention: Attention mask must be 2-D or 4-D");

  at::Tensor output = at::empty({batchSize, qSize, num_head, headSize}, query.options());
  const auto accumulate_dtype = toOpMathType(dtype);
  at::Tensor logsumexp = at::empty({batchSize, qSize, num_head},
      query.options().dtype(accumulate_dtype));

  flash_attention_kernel(kCPU, output, logsumexp,
      query, key, value, dropout_p, is_causal, attn_mask, scale);
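
  // The CPU kernel writes into buffers laid out as [B, T, num_head, head_size]
  // ([B, T, H] for logsumexp); the transposes below swap the T and num_head
  // dimensions so the returned tensors match the public [B, H, T, head_size]
  // convention without copying.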
  output = output.transpose(1, 2);
  logsumexp = logsumexp.transpose(1, 2);

  return std::make_tuple(std::move(output), std::move(logsumexp));
}

std::tuple<at::Tensor, at::Tensor, at::Tensor>
_scaled_dot_product_flash_attention_cpu_backward(
    const Tensor& grad_out,
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const Tensor& out,
    const Tensor& logsumexp,
    double dropout_p,
    bool is_causal,
    const std::optional<Tensor>& attn_mask,
    std::optional<double> scale) {
  if (!grad_out.defined()) {
    return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
  }
  auto grad_out_t = grad_out.transpose(1, 2);
  auto q_t = query.transpose(1, 2);
  auto k_t = key.transpose(1, 2);
  auto v_t = value.transpose(1, 2);
  auto o_t = out.transpose(1, 2);
  auto lse_t = logsumexp.transpose(1, 2);

  auto grad_q = at::zeros(q_t.sizes(), query.options());
  auto grad_k = at::zeros(k_t.sizes(), key.options());
  auto grad_v = at::zeros(v_t.sizes(), value.options());

  flash_attention_backward_kernel(kCPU, grad_q, grad_k, grad_v,
      grad_out_t, q_t, k_t, v_t, o_t, lse_t,
      dropout_p, is_causal, attn_mask, scale);

  grad_q = grad_q.transpose(1, 2);
  grad_k = grad_k.transpose(1, 2);
  grad_v = grad_v.transpose(1, 2);

  return std::make_tuple(std::move(grad_q), std::move(grad_k), std::move(grad_v));
}

Tensor triton_multi_head_attention(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const int64_t embed_dim,
    const int64_t num_head,
    const Tensor& qkv_weight,
    const Tensor& qkv_bias,
    const Tensor& proj_weight,
    const Tensor& proj_bias,
    const std::optional<Tensor>& mask) {
  // query shape: [B, T, D]
  // qkv_weight shape: [3 * D, D]
  TORCH_CHECK(!mask, "Only causal mask is supported for Triton.");

  const auto D = embed_dim;
  TORCH_CHECK(
      query.dim() == 3,
      "expected 3-D `query`, got ",
      query.dim(),
      "-D tensor");
  TORCH_CHECK(
      query.sizes()[2] == embed_dim,
      "passed-in embed_dim ",
      embed_dim,
      " didn't match last dim of query ",
      query.sizes()[2]);
  TORCH_CHECK(
      key.dim() == 3,
      "expected 3-D `key`, got ",
      key.dim(),
      "-D tensor");
  TORCH_CHECK(
      value.dim() == 3,
      "expected 3-D `value`, got ",
      value.dim(),
      "-D tensor");
  TORCH_CHECK(
      query.sizes() == key.sizes() && key.sizes() == value.sizes(),
      "expected `query`/`key`/`value` shapes to match");
  TORCH_CHECK(
      qkv_weight.dim() == 2,
      "expected 2-D `qkv_weight`, got ",
      qkv_weight.dim(),
      "-D tensor");
  TORCH_CHECK(
      D * 3 == qkv_weight.sizes()[0],
      "expected `qkv_weight` first dim to be 3x embed_dim");
  TORCH_CHECK(
      D == qkv_weight.sizes()[1],
      "expected `qkv_weight` second dim to be embed_dim");

#ifndef NDEBUG
  const auto B = query.is_nested()
      ? get_nested_tensor_impl(query)->get_nested_sizes().size(0)
      : query.sizes()[0];
  auto T = query.is_nested() ? 0 : query.sizes()[1];
  const auto dim_per_head = D / num_head;
#endif

  // shape: [B, T, 3 x D]
  auto qkv = qkv_projection(query, key, value, embed_dim, qkv_weight);

  // shape: 3 x [B, num_head, T, dim_per_head]
  auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
  qkv = Tensor(); // Not used any more, allow free
  auto& q = std::get<0>(q_k_v);
  const auto& k = std::get<1>(q_k_v);
  const auto& v = std::get<2>(q_k_v);
#ifndef NDEBUG
  debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head});
  debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head});
  debug_assert_shape(__LINE__, v, {B, num_head, T, dim_per_head});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "q: " << q << std::endl;
  std::cerr << "k: " << k << std::endl;
  std::cerr << "v: " << v << std::endl;
#endif

  auto attn_ctx = at::_triton_scaled_dot_attention(q, k, v);

#ifndef NDEBUG
  debug_assert_shape(__LINE__, attn_ctx, {B, num_head, T, dim_per_head});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "attn_ctx: " << attn_ctx << std::endl;
#endif

  // shape: [B, T, D]
  // Fuse transform_0213 inside
  auto proj = transform0213_gemm_nt_bias(
      attn_ctx, proj_weight, proj_bias, query);
#ifndef NDEBUG
  debug_assert_shape(__LINE__, proj, {B, T, D});
#endif
  return proj;
}
} // namespace native
} // namespace at