#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/TensorBody.h>
#include <ATen/core/Tensor.h>
#include <ATen/TensorOperators.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <c10/util/typeid.h>
#include <c10/core/DeviceType.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/Logging.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>

#include <type_traits>
#include <limits>
#include <utility>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_sdp_choice_native.h>
#include <ATen/ops/_masked_softmax.h>
#include <ATen/ops/_native_multi_head_attention_native.h>
#include <ATen/ops/_nested_from_padded.h>
#include <ATen/ops/_nested_tensor_softmax_with_shape.h>
#include <ATen/ops/_scaled_dot_product_attention_math.h>
#include <ATen/ops/_scaled_dot_product_attention_math_native.h>
#include <ATen/ops/_scaled_dot_product_efficient_attention.h>
#include <ATen/ops/_scaled_dot_product_flash_attention.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_backward_native.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_native.h>
#include <ATen/ops/_scaled_dot_product_cudnn_attention.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_native.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_native.h>
#include <ATen/ops/_softmax.h>
#include <ATen/ops/_transform_bias_rescale_qkv.h>
#include <ATen/ops/_transform_bias_rescale_qkv_native.h>
#include <ATen/ops/_triton_multi_head_attention_native.h>
#include <ATen/ops/_triton_scaled_dot_attention.h>
#include <ATen/ops/bmm.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/chunk_native.h>
#include <ATen/ops/dropout.h>
#include <ATen/ops/linear_native.h>
#include <ATen/ops/matmul.h>
#include <ATen/ops/matmul_native.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/pad.h>
#include <ATen/ops/scaled_dot_product_attention_native.h>
#include <ATen/ops/softmax.h>
#include <ATen/ops/split_native.h>
#include <ATen/ops/split_with_sizes_native.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif

#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
namespace at {

namespace native {

DEFINE_DISPATCH(_fused_sdp_choice_stub);

DEFINE_DISPATCH(transform_bias_rescale_qkv_stub);
DEFINE_DISPATCH(flash_attention_kernel);
DEFINE_DISPATCH(flash_attention_backward_kernel);

namespace {

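// Computes self @ other.t(), dispatching to the NestedTensor matmul when
// `self` is nested. The "nt" suffix means the second operand is transposed.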
Tensor gemm_nt(const Tensor& self, const Tensor& other) {
  if (self.is_nested()) {
    return NestedTensor_matmul(self, other.t());
  } else {
    return at::native::matmul(self, other.t());
  }
}

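// Permutes [B, num_head, T, dim_per_head] -> [B, T, num_head * dim_per_head],
// i.e. merges the per-head outputs back into a single embedding dimension.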
Tensor transform_0213(const Tensor& a) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(1));
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(3));
  return a.permute({0, 2, 1, 3})
      .contiguous()
      .view({a.size(0), a.size(2), a.size(1) * a.size(3)});
}

} // namespace


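// Batched a @ b.transpose(-2, -1) for 4-D inputs: the leading [B, num_head]
// dims are flattened into one batch dim for at::bmm and restored afterwards.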
Tensor bmm_nt(const Tensor& a, const Tensor& b) {
  auto a_ = a.view({a.size(0) * a.size(1), a.size(2), a.size(3)});
  auto b_ = b.view({b.size(0) * b.size(1), b.size(2), b.size(3)});
  auto bt_ = b_.transpose(2, 1);
  auto c_ = at::bmm(a_, bt_);
  return c_.view({a.size(0), a.size(1), a.size(2), b.size(2)});
}

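// Softmax over the last dim of attn_scores. If a mask is given it is converted
// to bool (if needed) and applied via _masked_softmax; nested queries without a
// mask take the nested softmax path; otherwise a plain in-place softmax is used.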
Tensor masked_softmax(
    Tensor& attn_scores,
    std::optional<Tensor> attn_mask,
    const Tensor& query,
    std::optional<int64_t> mask_type) {
  if (query.is_nested() && !attn_mask) {
    return at::_nested_tensor_softmax_with_shape(attn_scores, query);
  }
  if (attn_mask && attn_mask->dtype() != at::kBool) {
    attn_mask = attn_mask->to(at::kBool);
  }
  if (attn_mask) {
    return _masked_softmax(attn_scores, *attn_mask, attn_scores.dim() - 1, mask_type);
  } else {
    return _softmax_out(attn_scores, attn_scores, attn_scores.dim() - 1, false);
  }
}

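// Batched a @ b for 4-D inputs, writing the result into `out`'s storage
// (used below to reuse q's buffer for the attention context).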
Tensor bmm_nn(Tensor& out, const Tensor& a, const Tensor& b) {
  const std::array<int64_t, 3> newAShape = {
      a.sizes()[0] * a.sizes()[1], a.sizes()[2], a.sizes()[3]};
  auto a_ = a.view(newAShape);
  const std::array<int64_t, 3> newBShape = {
      b.sizes()[0] * b.sizes()[1], b.sizes()[2], b.sizes()[3]};
  auto b_ = b.view(newBShape);
  auto out_ = out.reshape({newAShape[0], newAShape[1], newBShape[2]});
  auto c_ = at::bmm_out(out_, a_, b_);
  return c_.view({a.size(0), a.size(1), a.size(2), b.size(3)});
}


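// Output projection fused with transform_0213: reshapes the attention context
// from [B, num_head, T, dim_per_head] to [B, T, D] and applies the projection
// linear(a, b, c), staying nested when the original query was a NestedTensor.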
Tensor transform0213_gemm_nt_bias(
    const Tensor& a,
    const Tensor& b,
    const Tensor& c,
    const Tensor& query) {
  if (query.is_nested()) {
    at::Tensor nested_a = _nested_from_padded(
        a, get_nested_tensor_impl(query)->get_nested_sizes(), true);
    return NestedTensor_times_Tensor_plus_Tensor_addmm(
        c, nested_a, b.t(), 1, 1);
  } else {
    const Tensor a_0213 = transform_0213(a);
    auto a_ = a_0213.view({a_0213.size(0) * a_0213.size(1), a_0213.size(2)});
    auto r_ = at::native::linear(a_, b, c);
    return r_.view({a_0213.size(0), a_0213.size(1), r_.size(1)});
  }
}

void debug_assert_shape(int line, const Tensor& t, c10::IntArrayRef shape) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      (size_t)t.dim() == shape.size(),
      "(called from line ",
      line,
      ") ",
      "expected ",
      shape.size(),
      "-D tensor but got ",
      t.dim());
  if (t.is_nested()) {
    return;
  }
  for (auto idx : c10::irange(shape.size())) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        shape[idx] == 0 || t.sizes()[idx] == shape[idx],
        "(called from line ",
        line,
        ") ",
        "expected dim ",
        idx,
        " to be ",
        shape[idx],
        " but got ",
        t.sizes()[idx]);
  }
}

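// Projects query/key/value into a packed [B, T, 3 * D] tensor using qkv_weight.
// Self-attention (query == key == value) uses a single GEMM; encoder-decoder
// attention (key == value only) splits the weight into Q and KV projections;
// otherwise the weight is chunked into three separate projections.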
Tensor qkv_projection(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const int64_t embed_dim,
    const Tensor& qkv_weight) {
  // shape: [B, T, 3 x D]
  Tensor qkv;

  if (key.is_same(value)) {
    if (query.is_same(key)) {
      // self-attention
      qkv = gemm_nt(query, qkv_weight);
    } else {
      // encoder-decoder attention
      // TODO: is there a more efficient way to set this up?
      // TODO: can we stay nested instead of using cat? Probably just make a
      // NestedTensor out of the matmul results or something?
      auto q_kv_weight_s =
          at::native::split_with_sizes(qkv_weight, {embed_dim, embed_dim * 2}, 0);
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          q_kv_weight_s.size() == 2,
          "expected split to produce 2 tensors but it produced ",
          q_kv_weight_s.size());
      auto q = gemm_nt(query, q_kv_weight_s[0]);
      auto kv = gemm_nt(key, q_kv_weight_s[1]);
      qkv = at::cat({std::move(q), std::move(kv)}, 2);
    }
  } else {
    auto q_k_v_weight_s = at::native::chunk(qkv_weight, 3, 0);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        q_k_v_weight_s.size() == 3,
        "expected chunk to produce 3 tensors but it produced ",
        q_k_v_weight_s.size());
    // TODO: can we stay nested instead of using cat?
    auto q = gemm_nt(query, q_k_v_weight_s[0]);
    auto k = gemm_nt(key, q_k_v_weight_s[1]);
    auto v = gemm_nt(value, q_k_v_weight_s[2]);
    qkv = at::cat({std::move(q), std::move(k), std::move(v)}, 2);
  }

  return qkv;
}

// compute q = (q + q_bias) / sqrt(dim_per_head), k = k + k_bias, v = v + v_bias
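// Takes the packed [B, T, 3 * D] qkv tensor (padding nested inputs first) and
// returns q, k, v, each of shape [B, num_head, T, dim_per_head].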
std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_cpu(
    const Tensor& qkv,
    const Tensor& qkv_bias,
    const int64_t num_head) {
  auto qkv_ = qkv.is_nested()
      ? c10::MaybeOwned<Tensor>::owned(qkv.to_padded_tensor(0))
      : c10::MaybeOwned<Tensor>::borrowed(qkv);
  auto B = qkv_->size(0);
  auto T = qkv_->size(1);
  auto _3D = qkv_->size(2);
  auto D = _3D / 3;
  TORCH_CHECK(D % num_head == 0);
  TORCH_CHECK(_3D % 3 == 0);
  const auto dim_per_head = D / num_head;
  auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv_->options());

  const auto qkv_contig = qkv_->expect_contiguous();
  const auto qkv_bias_contig = qkv_bias.expect_contiguous();
  transform_bias_rescale_qkv_stub(
      kCPU,
      qkv_->scalar_type(),
      q_k_v.data_ptr(),
      qkv_contig->const_data_ptr(),
      qkv_bias_contig->const_data_ptr(),
      B, T, D, num_head);
  auto q_k_v_s =
      at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(q_k_v_s.size() == 3);
  return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
}

std::tuple<Tensor, Tensor> native_multi_head_attention_cpu(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const int64_t embed_dim,
    const int64_t num_head,
    const Tensor& qkv_weight,
    const Tensor& qkv_bias,
    const Tensor& proj_weight,
    const Tensor& proj_bias,
    const std::optional<Tensor>& mask,
    bool need_weights,
    bool average_attn_weights,
    const std::optional<int64_t> mask_type) {
  // query shape: [B, T, D]
  // qkv_weight shape: [3 * D, D]

  TORCH_CHECK(
      !mask || !query.is_nested(),
      "NestedTensor with mask is not supported yet");
  const auto D = embed_dim;
  TORCH_CHECK(
      query.dim() == 3,
      "expected 3-D `query`, got ",
      query.dim(),
      "-D tensor");
  TORCH_CHECK(
      query.is_nested() || query.sizes()[2] == embed_dim,
      "passed-in embed_dim ",
      embed_dim,
      " didn't match last dim of query ",
      query.sizes()[2]);
  TORCH_CHECK(
      key.dim() == 3,
      "expected 3-D `key`, got ",
      key.dim(),
      "-D tensor");
  TORCH_CHECK(
      value.dim() == 3,
      "expected 3-D `value`, got ",
      value.dim(),
      "-D tensor");
  TORCH_CHECK(
      query.is_nested() || key.is_nested() || value.is_nested() ||
          (query.sizes() == key.sizes() && key.sizes() == value.sizes()),
      "expected `query`/`key`/`value` shapes to match");
  TORCH_CHECK(
      qkv_weight.dim() == 2,
      "expected 2-D `qkv_weight`, got ",
      qkv_weight.dim(),
      "-D tensor");
  TORCH_CHECK(
      D * 3 == qkv_weight.sizes()[0],
      "expected `qkv_weight` first dim to be 3x embed_dim");
  TORCH_CHECK(
      D == qkv_weight.sizes()[1],
      "expected `qkv_weight` second dim to be embed_dim");
  TORCH_CHECK(
      qkv_bias.dim() == 1,
      "expected 1-D `qkv_bias`, got ",
      qkv_bias.dim(),
      "-D tensor");
  TORCH_CHECK(
      qkv_bias.sizes()[0] == 3 * D,
      "expected `qkv_bias` first dim to be 3x embed_dim");
  TORCH_CHECK(D % num_head == 0, "`embed_dim` must divide evenly by `num_heads`");

#ifndef NDEBUG
  const auto B = query.is_nested()
      ? get_nested_tensor_impl(query)->get_nested_sizes().size(0)
      : query.sizes()[0];
  auto T = query.is_nested() ? 0 : query.sizes()[1];
  const auto dim_per_head = D / num_head;
#endif

  // shape: [B, T, 3 x D]
  auto qkv = qkv_projection(query, key, value, embed_dim, qkv_weight);

  if (!qkv.is_nested() && qkv.numel() == 0) {
    if (query.is_nested()) {
      return std::make_tuple(Tensor(), Tensor());
    }
    return std::make_tuple(at::empty_like(query), Tensor());
  }

#ifndef NDEBUG
  if (!query.is_nested() || !qkv.is_nested()) {
    if (query.is_nested()) {
      T = qkv.size(1);
    }
    debug_assert_shape(__LINE__, qkv, {B, T, 3 * D});
  }
#endif

#ifdef DEBUG_PRINT_EACH_STEP
  if (!qkv.is_nested()) {
    std::cerr << "qkv: " << qkv << std::endl;
  }
#endif
  // shape: 3 x [B, num_head, T, dim_per_head]
  auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
  qkv = Tensor(); // Not used any more, allow free
  auto& q = std::get<0>(q_k_v);
  const auto& k = std::get<1>(q_k_v);
  const auto& v = std::get<2>(q_k_v);
#ifndef NDEBUG
  debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head});
  debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head});
  debug_assert_shape(__LINE__, v, {B, num_head, T, dim_per_head});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "q: " << q << std::endl;
  std::cerr << "k: " << k << std::endl;
  std::cerr << "v: " << v << std::endl;
#endif

  // shape: [B, num_head, T, T]
  auto qkt = bmm_nt(q, k);
  // q & k are dead but cannot be freed because they were packed with v
#ifndef NDEBUG
  debug_assert_shape(__LINE__, qkt, {B, num_head, T, T});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "qkt: " << qkt << std::endl;
#endif

  // shape: [B, num_head, T, T]
  // TODO: long-term, have a kernel that works with
  // NestedTensor directly if there is no mask passed
  qkt = masked_softmax(qkt, mask, query, mask_type);
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "qkt after softmax: " << qkt << std::endl;
#endif

  // shape: [B, num_head, T, dim_per_head]
  // reuse storage for q; we're done with it
  auto attn_ctx = bmm_nn(q, qkt, v);
  // qkv is not dead; we just reused storage for q!
  if (!need_weights) {
    qkt = Tensor();
  }
#ifndef NDEBUG
  debug_assert_shape(__LINE__, attn_ctx, {B, num_head, T, dim_per_head});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "attn_ctx: " << attn_ctx << std::endl;
#endif

  // shape: [B, T, D]
  // Fuse transform_0213 inside
  auto proj = transform0213_gemm_nt_bias(
      attn_ctx, proj_weight, proj_bias, query);
#ifndef NDEBUG
  debug_assert_shape(__LINE__, proj, {B, T, D});
#endif
  if (need_weights && average_attn_weights) {
    // weights are not needed for full transformer, so don't worry too
    // much about performance -- we implement this just so that use
    // cases that don't disable need_weights still get some speedup.
    qkt = qkt.sum(1);
    qkt /= num_head;
  }
  return std::make_tuple(std::move(proj), std::move(qkt));
}

int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value,
    const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale){
  sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal};
  auto backend = sdp::select_sdp_backend_cpp(kernel_params);
  if (backend == sdp::SDPBackend::error) {
    TORCH_CHECK(
        false,
        "No viable backend for scaled_dot_product_attention was found. ",
        "This is likely due to turning off both the math kernel and the fused kernels.");
  }
  return static_cast<int64_t>(backend);
}

REGISTER_ARCH_DISPATCH(_fused_sdp_choice_stub, DEFAULT, &_fused_sdp_choice_cpp);
REGISTER_AVX2_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);

int64_t _fused_sdp_choice_meta(
    const Tensor& query_,
    const Tensor& key,
    const Tensor& value,
    const std::optional<Tensor>& attn_mask_,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale) {
  auto query_key_set = query_.key_set();
#if defined(USE_ROCM)
  bool has_rocm = query_key_set.has(c10::DispatchKey::HIP);
  if (has_rocm) {
    auto choice_int = _fused_sdp_choice_stub(at::kHIP, query_, key, value, attn_mask_, dropout_p, is_causal, scale);
    return choice_int;
  }
#else
  bool has_cuda = query_key_set.has(c10::DispatchKey::CUDA);
  if (has_cuda) {
    auto choice_int = _fused_sdp_choice_stub(
        at::kCUDA,
        query_,
        key,
        value,
        attn_mask_,
        dropout_p,
        is_causal,
        scale);
    return choice_int;
  }
#endif
  return static_cast<int64_t>(sdp::SDPBackend::math);
}
namespace {

inline void validate_sdpa_input(
    const Tensor& query_,
    const Tensor& key,
    const Tensor& value,
    const std::optional<Tensor>& attn_mask_,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale) {
  TORCH_CHECK(
      query_.dtype() == key.dtype() && query_.dtype() == value.dtype(),
      "Expected query, key, and value to have the same dtype, but got query.dtype: ",
      query_.dtype(), " key.dtype: ", key.dtype(), " and value.dtype: ", value.dtype(), " instead.");
  TORCH_CHECK(
      query_.device() == key.device() && query_.device() == value.device(),
      "Expected query, key, and value to have the same device type, but got query.device: ",
      query_.device(), " key.device: ", key.device(), " and value.device: ", value.device(), " instead.");
  TORCH_CHECK(
      query_.dim() >= 2 && key.dim() >= 2 && value.dim() >= 2,
      "Expected query, key, and value to all be at least 2 dimensional, but got query.dim: ",
      query_.dim(), " key.dim: ", key.dim(), " and value.dim: ", value.dim(), " instead.");
  if (attn_mask_.has_value()){
    auto mask_dtype = attn_mask_->dtype();
    TORCH_CHECK(mask_dtype == at::kBool || mask_dtype == at::kFloat || mask_dtype == query_.dtype(),
        "Expected attn_mask dtype to be bool or float or to match query dtype, but got attn_mask.dtype: ",
        mask_dtype, " and query.dtype: ", query_.dtype(), " instead.");
    TORCH_CHECK(
        !query_.is_nested() && !key.is_nested(),
        "Scaled_dot_product_attention: Nested tensors for query / key are not supported "
        "when an explicit attn_mask is set");
  }
  return;
}
// This function is used to produce an attn_mask
// in a standard format that can be consumed by both
// the math and memory-efficient attention implementations
// Args:
//     attn_mask: attn_mask of shape (B, L, S) or (L, S) or (B, N_heads, L, S)
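// A boolean mask becomes an additive float mask: True (keep) entries map to
// 0.0 and False entries map to -inf, so masked positions vanish after softmax.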
std::optional<Tensor> convert_boolean_attn_mask(const std::optional<Tensor>& attn_mask, caffe2::TypeMeta dtype) {
  // Pass through
  if(!attn_mask.has_value()){
    return c10::nullopt;
  }
  // Convert boolean mask to additive mask; need to invert mask to indicate what
  // to mask *out*.
  if (attn_mask->dtype() == at::kBool) {
    auto new_attn_mask = at::zeros_like(attn_mask.value(), dtype);
    // TODO Use the max type of the input and output
    new_attn_mask.masked_fill_(
        attn_mask->logical_not(), -std::numeric_limits<double>::infinity());
    return new_attn_mask;
  }
  // Otherwise, attn_mask represents an additive attention tensor
  return attn_mask;
}
// Memory Efficient Attention requires a padded attn mask bias.
// This function pads the attn_mask bias to be a multiple of 16
// and then slices the padded bias back to the original size.
// We apply this function at the top-level SDPA so that,
// if padding is done, it will be tracked for backward automatically.

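// aligned_tensor checks that every stride except the last is a multiple of
// `alignment` and that the last dim is contiguous (stride 1); pad_bias pads the
// last dim up to the next multiple and slices back, so the values are unchanged
// but the underlying storage (and strides) satisfy the alignment requirement.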
template<int alignment>
bool aligned_tensor(const at::Tensor& tensor){
  for(const auto i : c10::irange(tensor.dim() - 1)){
    if(tensor.sym_stride(i) % alignment != 0){
      return false;
    }
  }
  return tensor.sym_stride(-1) == 1;
}

template <int alignment>
at::Tensor pad_bias(const at::Tensor& attn_bias) {
  auto last_dim_size = attn_bias.sym_size(-1);
  auto pad_count = alignment - (last_dim_size % alignment);
  auto padded_bias = at::pad_symint(attn_bias, {c10::SymInt(0), pad_count});
  return padded_bias.slice_symint(-1, 0, last_dim_size);
}

at::Tensor preprocess_mask(
    const at::Tensor& mask,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value) {
  constexpr int mem_eff_alignment = 8;
  at::Tensor result_mask = mask;
  if (!aligned_tensor<mem_eff_alignment>(mask)) {
    result_mask = pad_bias<mem_eff_alignment>(mask);
  }
  return result_mask.expand_symint(
      {query.sym_size(0),
       query.sym_size(1),
       query.sym_size(2),
       key.sym_size(2)});
}
// FlashAttentionV2 requires that the head dimension be a multiple of 8.
// This was previously done within the kernel; however, that could cause the
// kernel to alias query, key, and value, so instead we pad the head dimension
// to a multiple of 8 in the composite region.
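// For example, with alignment_size 8 a head dim of 96 passes through untouched,
// while a head dim of 100 is padded to 104 (and sliced back to 100 when `slice`
// is true).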
template <int alignment_size, bool slice>
at::Tensor pad_last_dim(const at::Tensor& attn_bias) {
  auto last_dim_size = attn_bias.sym_size(-1);
  if (last_dim_size % alignment_size == 0) {
    return attn_bias;
  }
  auto pad_count = alignment_size - (last_dim_size % alignment_size);
  auto padded_bias = at::pad_symint(attn_bias, {c10::SymInt(0), pad_count});
  if (slice) {
    return padded_bias.slice_symint(-1, 0, last_dim_size);
  }
  return padded_bias;
}

at::Tensor post_process_flash_output(
    at::Tensor out,
    c10::SymInt const& og_size) {
  if (!out.is_nested() && out.sym_size(-1) != og_size) {
    out = out.slice_symint(-1, 0, og_size);
  }
  return out;
}

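// Queries the fused-SDP choice stub for the input's device type (e.g. a
// PrivateUse backend) and falls back to the math backend if the stub throws.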
int64_t handle_private_use(const Tensor& query_, const Tensor& key, const Tensor& value,
    const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale){
  int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
  try {
    choice_int = _fused_sdp_choice_stub(query_.device().type(),
        query_, key, value, attn_mask_, dropout_p, is_causal, scale);
  } catch(const ::c10::Error& e){
  }
  return choice_int;
}

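// logsumexp is only needed to recompute the softmax in the backward pass, so it
// is computed only when autograd is enabled and some input requires grad.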
bool should_compute_logsumexp(const Tensor& query, const Tensor& key, const Tensor& value) {
  const bool any_inputs_require_grad = query.requires_grad() || key.requires_grad() || value.requires_grad();
  const bool gradmode_enabled = at::GradMode::is_enabled();
  return any_inputs_require_grad && gradmode_enabled;
}

} // namespace

// Computes scaled dot product attention on query, key and value tensors, using
// an optional attention mask if passed, and applying dropout if a probability
// greater than 0.0 is specified.
//
// Args:
//     query (Tensor): Query tensor; shape (N, ..., L, E)
//     key (Tensor): Key tensor; shape (N, ..., S, E)
//     value (Tensor): Value tensor; shape (N, ..., S, E)
//     attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights,
//         which is (N, ..., L, S). Two types of masks are supported.
//         A boolean mask where a value of True indicates that the element *should* take part in attention.
//         A float mask of the same type as query, key, value that is added to the attention score.
//     dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
//     need_attn_weights (bool): If true, the second return value will contain the attention weights used;
//         otherwise, the second return value is unspecified
//     is_causal (bool): If true, assumes causal attention masking; for this case, attn_mask should not be set.
//         TODO: Consider removing this flag before promoting this function to the public API. It's possible
//         to get specialized support for causal masks (and other types of masking e.g. local attention / block
//         sparse masks) via tensor subclassing, allowing for a leaner API.
//
// Returns a tensor:
//     output (Tensor): Attention output; shape (N, ..., L, E)
//
// Shape legend:
//     N: Batch size
//     ...: Any number of other batch dimensions (optional)
//     S: Source sequence length
//     L: Target sequence length
//     E: Embedding dimension
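// Mathematically (ignoring dropout and masking details) this computes
//     output = softmax(Q @ K.transpose(-2, -1) * scale + attn_mask) @ V
// where scale defaults to 1 / sqrt(E); the backends selected below are just
// different kernels for this same computation.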
Tensor scaled_dot_product_attention(
    const Tensor& query_,
    const Tensor& key,
    const Tensor& value,
    const std::optional<Tensor>& attn_mask_,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale) {
  validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal, scale);
  int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
  if (_fused_sdp_choice_stub.is_device_supported(query_.device().type())) {
    choice_int = _fused_sdp_choice_stub(query_.device().type(),
        query_, key, value, attn_mask_, dropout_p, is_causal, scale);
  }
  sdp::SDPBackend backend = static_cast<sdp::SDPBackend>(choice_int);
  std::optional<Tensor> attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype());
  switch (backend) {
    case sdp::SDPBackend::cudnn_attention: {
      bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
      auto out_lse_softmax = at::_scaled_dot_product_cudnn_attention(
          query_, key, value, dropout_p, is_causal, compute_logsumexp, scale);
      return std::get<0>(out_lse_softmax);
    }
    case sdp::SDPBackend::flash_attention: {
      if(query_.device().type() == DeviceType::CUDA){
        c10::SymInt og_size = query_.sym_size(-1);
        Tensor query_padded = pad_last_dim<8, false>(query_);
        Tensor key_padded = pad_last_dim<8, false>(key);
        Tensor value_padded = pad_last_dim<8, false>(value);
        // We need to calculate the scale based on the original head dim size
        auto og_scale = sdp::calculate_scale(query_, scale);
        auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
            query_padded, key_padded, value_padded, dropout_p, is_causal, false /*return_debug_mask*/, og_scale.as_float_unchecked());
        return post_process_flash_output(std::get<0>(out_lse_softmax), og_size);
      }
      // For the CPU case we do not need to pad the last dim
      return std::get<0>(at::_scaled_dot_product_flash_attention_for_cpu(
          query_, key, value, dropout_p, is_causal, attn_mask, scale));
    }
    case sdp::SDPBackend::efficient_attention: {
      bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
      if (attn_mask.has_value()) {
        attn_mask.value() = preprocess_mask(attn_mask.value(), query_, key, value);
      }
      auto out_and_lse = at::_scaled_dot_product_efficient_attention(
          query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal, scale);
      return std::get<0>(out_and_lse);
    }
    case sdp::SDPBackend::math:
      return std::get<0>(at::_scaled_dot_product_attention_math(
          query_,
          key,
          value,
          attn_mask,
          dropout_p,
          is_causal,
          c10::nullopt, /*dropout_mask*/
          scale));
    default:
      TORCH_CHECK(
          false,
          "No viable backend for scaled_dot_product_attention was found.");
      return Tensor();
  }
}

std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
    const Tensor& query_, const Tensor& key, const Tensor& value,
    const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal,
    const std::optional<Tensor>& dropout_mask, std::optional<double> scale) {
  C10_LOG_API_USAGE_ONCE("torch.sdpa.math_fallback");
  if (query_.is_nested() || key.is_nested() || value.is_nested()) {
    TORCH_CHECK(
        query_.is_contiguous() && key.is_contiguous() &&
            value.is_contiguous(),
        "scaled_dot_product_attention: If inputs are nested tensors they must be contiguous");
  }
  auto attn_mask = attn_mask_;
  // Naive, composite implementation defined here.

  // Scale q, k before matmul for stability; see https://tinyurl.com/sudb9s96 for the math
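  // The scale is split as sqrt(scale) applied to both q and k^T, so that
  // (q * sqrt(scale)) @ (k^T * sqrt(scale)) == scale * (q @ k^T) while keeping
  // the magnitudes of the intermediate products smaller.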
  bool is_negative_scaling = scale.has_value() && scale.value() < 0.0;
  const auto scaling_factor = sdp::calculate_scale(query_, is_negative_scaling ? std::abs(scale.value()) : scale).sqrt();

  const auto query = query_ * (is_negative_scaling ? c10::SymFloat(0.0) - scaling_factor: scaling_factor);
  if (is_causal) {
    TORCH_CHECK(!attn_mask.has_value(),
        "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");
    TORCH_CHECK(!query.is_nested() && !key.is_nested(),
        "_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True");

    // Replace attn_mask with causal mask; lower triangular elements take part in attention.
    const auto L = query.sym_size(-2), S = key.sym_size(-2);
    attn_mask = at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
    attn_mask = convert_boolean_attn_mask(attn_mask, query.dtype());
  }
  auto attn = at::matmul(query, key.transpose(-2, -1) * scaling_factor);
  if (attn_mask.has_value()) {
    if (at::areAnyTensorSubclassLike({attn, *attn_mask})) {
      attn = attn.add(*attn_mask);
    } else {
      attn.add_(*attn_mask);
    }
  }
  attn = at::softmax(attn, -1);
  if (dropout_p > 0.0) {
    if (dropout_mask.has_value()) {
      // In order to validate the correctness of the fused kernels, we need to
      // use the same dropout mask when comparing the results.
      TORCH_WARN_ONCE("Dropout mask should only be used for testing purposes.");
      attn = attn.masked_fill(dropout_mask->logical_not(), 0.0);
      auto dropout_scaling = 1.0 / (1 - dropout_p);
      return std::make_tuple(at::matmul(attn, value * dropout_scaling), attn);
    } else {
      attn = at::dropout(attn, dropout_p, true);
    }
  }

  return std::make_tuple(at::matmul(attn, value), attn);
}

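// CPU flash attention entry point. The output and logsumexp buffers are
// allocated in [B, T, H, K] / [B, T, H] layout for the kernel and transposed
// back to the [B, H, T, K] / [B, H, T] convention before returning.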
std::tuple<at::Tensor, at::Tensor>
_scaled_dot_product_flash_attention_cpu(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    double dropout_p,
    bool is_causal,
    const std::optional<Tensor>& attn_mask,
    std::optional<double> scale) {
  const auto dtype = query.scalar_type();
  int64_t batchSize = query.size(0);
  int64_t qSize = query.size(2);
  int64_t num_head = query.size(1);
  int64_t headSize = query.size(3);

  TORCH_CHECK(c10::isFloatingType(dtype),
      "scaled_dot_product_attention_flash_attention: Expected data type in FP32, FP64, BF16, FP16, but got ", dtype, " instead.");
  TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
      "scaled_dot_product_attention_flash_attention: Only 4-dim inputs with shape {B, H, T, K} are accepted");
  TORCH_CHECK(dropout_p == 0.0,
      "scaled_dot_product_attention_flash_attention: dropout > 0 is currently not supported");
  TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
      "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");
  TORCH_CHECK(!attn_mask.has_value() ||
      attn_mask.value().scalar_type() == at::kFloat ||
      dtype == attn_mask.value().scalar_type(),
      "scaled_dot_product_attention_flash_attention: Attention mask must be float or the same data type as query");
  TORCH_CHECK(!attn_mask.has_value() ||
      (attn_mask.value().dim() == 2 || attn_mask.value().dim() == 4),
      "scaled_dot_product_attention_flash_attention: Attention mask dim must be 2 or 4");

  at::Tensor output = at::empty({batchSize, qSize, num_head, headSize}, query.options());
  const auto accumulate_dtype = toOpMathType(dtype);
  at::Tensor logsumexp = at::empty({batchSize, qSize, num_head},
      query.options().dtype(accumulate_dtype));

  flash_attention_kernel(kCPU, output, logsumexp,
      query, key, value, dropout_p, is_causal, attn_mask, scale);

  output = output.transpose(1, 2);
  logsumexp = logsumexp.transpose(1, 2);

  return std::make_tuple(std::move(output), std::move(logsumexp));
}

std::tuple<at::Tensor, at::Tensor, at::Tensor>
_scaled_dot_product_flash_attention_cpu_backward(
    const Tensor& grad_out,
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const Tensor& out,
    const Tensor& logsumexp,
    double dropout_p,
    bool is_causal,
    const std::optional<Tensor>& attn_mask,
    std::optional<double> scale) {
  if (!grad_out.defined()) {
    return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
  }
  auto grad_out_t = grad_out.transpose(1, 2);
  auto q_t = query.transpose(1, 2);
  auto k_t = key.transpose(1, 2);
  auto v_t = value.transpose(1, 2);
  auto o_t = out.transpose(1, 2);
  auto lse_t = logsumexp.transpose(1, 2);

  auto grad_q = at::zeros(q_t.sizes(), query.options());
  auto grad_k = at::zeros(k_t.sizes(), key.options());
  auto grad_v = at::zeros(v_t.sizes(), value.options());

  flash_attention_backward_kernel(kCPU, grad_q, grad_k, grad_v,
      grad_out_t, q_t, k_t, v_t, o_t, lse_t,
      dropout_p, is_causal, attn_mask, scale);

  grad_q = grad_q.transpose(1, 2);
  grad_k = grad_k.transpose(1, 2);
  grad_v = grad_v.transpose(1, 2);

  return std::make_tuple(std::move(grad_q), std::move(grad_k), std::move(grad_v));
}

Tensor triton_multi_head_attention(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const int64_t embed_dim,
    const int64_t num_head,
    const Tensor& qkv_weight,
    const Tensor& qkv_bias,
    const Tensor& proj_weight,
    const Tensor& proj_bias,
    const std::optional<Tensor>& mask) {
  // query shape: [B, T, D]
  // qkv_weight shape: [3 * D, D]
  TORCH_CHECK(!mask, "Only causal mask is supported for Triton.");

  const auto D = embed_dim;
  TORCH_CHECK(
      query.dim() == 3,
      "expected 3-D `query`, got ",
      query.dim(),
      "-D tensor");
  TORCH_CHECK(
      query.sizes()[2] == embed_dim,
      "passed-in embed_dim ",
      embed_dim,
      " didn't match last dim of query ",
      query.sizes()[2]);
  TORCH_CHECK(
      key.dim() == 3,
      "expected 3-D `key`, got ",
      key.dim(),
      "-D tensor");
  TORCH_CHECK(
      value.dim() == 3,
      "expected 3-D `value`, got ",
      value.dim(),
      "-D tensor");
  TORCH_CHECK(
      query.sizes() == key.sizes() && key.sizes() == value.sizes(),
      "expected `query`/`key`/`value` shapes to match");
  TORCH_CHECK(
      qkv_weight.dim() == 2,
      "expected 2-D `qkv_weight`, got ",
      qkv_weight.dim(),
      "-D tensor");
  TORCH_CHECK(
      D * 3 == qkv_weight.sizes()[0],
      "expected `qkv_weight` first dim to be 3x embed_dim");
  TORCH_CHECK(
      D == qkv_weight.sizes()[1],
      "expected `qkv_weight` second dim to be embed_dim");

#ifndef NDEBUG
  const auto B = query.is_nested()
      ? get_nested_tensor_impl(query)->get_nested_sizes().size(0)
      : query.sizes()[0];
  auto T = query.is_nested() ? 0 : query.sizes()[1];
  const auto dim_per_head = D / num_head;
#endif

  // shape: [B, T, 3 x D]
  auto qkv = qkv_projection(query, key, value, embed_dim, qkv_weight);

  // shape: 3 x [B, num_head, T, dim_per_head]
  auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
  qkv = Tensor(); // Not used any more, allow free
  auto& q = std::get<0>(q_k_v);
  const auto& k = std::get<1>(q_k_v);
  const auto& v = std::get<2>(q_k_v);
#ifndef NDEBUG
  debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head});
  debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head});
  debug_assert_shape(__LINE__, v, {B, num_head, T, dim_per_head});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "q: " << q << std::endl;
  std::cerr << "k: " << k << std::endl;
  std::cerr << "v: " << v << std::endl;
#endif

  auto attn_ctx = at::_triton_scaled_dot_attention(q, k, v);

#ifndef NDEBUG
  debug_assert_shape(__LINE__, attn_ctx, {B, num_head, T, dim_per_head});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "attn_ctx: " << attn_ctx << std::endl;
#endif

  // shape: [B, T, D]
  // Fuse transform_0213 inside
  auto proj = transform0213_gemm_nt_bias(
      attn_ctx, proj_weight, proj_bias, query);
#ifndef NDEBUG
  debug_assert_shape(__LINE__, proj, {B, T, D});
#endif
  return proj;
}
} // namespace native
} // namespace at