/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/llm/custom_ops/op_sdpa.h>

#include <executorch/kernels/optimized/blas/CPUBlas.h>
#include <executorch/kernels/optimized/vec/functional.h>
#include <executorch/kernels/optimized/vec/vec.h>
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
// @lint-ignore CLANGTIDY facebook-unused-include-check
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>

#include <array>
#include <vector>

#ifdef ET_USE_THREADPOOL
#include <executorch/extension/parallel/thread_parallel.h>
#include <executorch/extension/threadpool/threadpool.h>
#endif
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>

namespace torch {
namespace executor {

namespace native {

namespace util {

constexpr size_t kKVDim = 4;

template <typename T>
inline void _store(T* dst, ::executorch::vec::Vectorized<T> src) {
  src.store(dst);
}

/*
inline void _store(::Half* dst, at::vec::Vectorized<float> src) {
  //fp16_ieee_to_fp32_value
  auto res = at::vec::convert_float_half(src, src);
  res.store(dst, at::vec::Vectorized<float>::size());
}
*/

template <typename T>
inline T data_index_init(T offset) {
  return offset;
}

template <typename T, typename... Args>
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
  offset = data_index_init(offset, std::forward<Args>(args)...);
  x = offset % X;
  return offset / X;
}

inline bool data_index_step() {
  return true;
}

template <typename T, typename... Args>
inline bool data_index_step(T& x, const T& X, Args&&... args) {
  if (data_index_step(std::forward<Args>(args)...)) {
    x = ((x + 1) == X) ? 0 : (x + 1);
    return x == 0;
  }
  return false;
}
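
// Usage sketch: these helpers decompose a flat index into row-major
// multi-dimensional counters and then advance them, innermost dimension
// first. For example, with dims {2, 3}:
//
//   int64_t i = 0, j = 0;
//   util::data_index_init(int64_t{4}, i, int64_t{2}, j, int64_t{3});
//   // now i == 1, j == 1 (flat index 4 == 1 * 3 + 1)
//   util::data_index_step(i, int64_t{2}, j, int64_t{3});
//   // advances to i == 1, j == 2; the innermost counter wraps first.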

inline double calculate_scale(const Tensor& query, optional<double> scale) {
  const auto softmax_scale =
      scale.has_value() ? scale.value() : 1.0 / std::sqrt(query.size(3));
  return softmax_scale;
}
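
// Note: when `scale` is not provided, calculate_scale() falls back to the
// standard softmax scaling of 1 / sqrt(head_dim) (query.size(3) here); e.g.
// head_dim = 64 gives a scale of 0.125.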

} // namespace util
namespace vec = ::executorch::vec;
using Tensor = exec_aten::Tensor;

namespace {

// 1) out = exp(a - val)
// 2) val = sum(out)
template <typename T1, typename T2>
inline void
_exp_reduce_sum_fusion_kernel(T1* a, const int& size, T2* out, T1& val) {
  auto vec_size = vec::Vectorized<T1>::size();
  auto vec_max = vec::Vectorized<T1>(val);
  T1 tmp_sum = 0;
  auto vec_tmp_sum = vec::Vectorized<T1>(tmp_sum);
  for (int i = 0; i < vec_size * (size / vec_size); i += vec_size) {
    auto tmp0 = vec::Vectorized<T1>::loadu(a + i);
    auto tmp1 = tmp0 - vec_max;
    // Replace with exp_u20 later
    // auto tmp2 = tmp1.exp_u20();
    auto tmp2 = tmp1.exp();
    vec_tmp_sum += tmp2;
    util::_store(out + i, tmp2);
  }
  tmp_sum = vec::vec_reduce_all<T1>(
      [](vec::Vectorized<T1>& x, vec::Vectorized<T1>& y) { return x + y; },
      vec_tmp_sum);
  for (int i = vec_size * (size / vec_size); i < size; i++) {
    auto tmp0 = a[i];
    auto tmp1 = tmp0 - val;
    auto tmp2 = exp(tmp1);
    tmp_sum += tmp2;
    out[i] = tmp2;
  }
  val = tmp_sum;
}

// 1) out = a * scale
// 2) max = max(out)
template <typename scalar_t>
inline void _mul_reduce_max_fusion_kernel(
    const scalar_t* a,
    const scalar_t& scale,
    const int& size,
    scalar_t* out,
    scalar_t& max) {
  auto vec_size = vec::Vectorized<scalar_t>::size();
  auto vec_scale = vec::Vectorized<scalar_t>(scale);
  scalar_t tmp_max = -std::numeric_limits<scalar_t>::infinity();
  auto vec_tmp_max = vec::Vectorized<scalar_t>(tmp_max);
  for (int i = 0; i < vec_size * (size / vec_size); i += vec_size) {
    auto tmp0 = vec::Vectorized<scalar_t>::loadu(a + i);
    auto tmp1 = tmp0 * vec_scale;
    vec_tmp_max = vec::maximum(vec_tmp_max, tmp1);
    util::_store(out + i, tmp1);
  }
  for (int i = vec_size * (size / vec_size); i < size; i++) {
    auto tmp0 = a[i];
    auto tmp1 = tmp0 * scale;
    tmp_max = std::max(tmp_max, tmp1);
    out[i] = tmp1;
  }
  max = std::max(
      tmp_max,
      vec::vec_reduce_all<scalar_t>(
          [](vec::Vectorized<scalar_t>& x, vec::Vectorized<scalar_t>& y) {
            return vec::maximum(x, y);
          },
          vec_tmp_max));
}

template <typename scalar_t>
static inline scalar_t* conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) {
  ET_CHECK(ptr2 == nullptr);
  return ptr;
}

template <
    typename scalar_t,
    typename std::enable_if_t<
        ::executorch::runtime::is_reduced_floating_point_v<scalar_t>,
        int> = 0>
static inline scalar_t* conditional_data_ptr(float* ptr, scalar_t* ptr2) {
  (void)ptr;
  return ptr2;
}

template <typename scalar_t>
inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  Vec data_vec = Vec(val);
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    data_vec.store(data + d);
  }
  for (; d < size; d++) {
    data[d] = val;
  }
}

/*
  Note on start_pos as a parameter:
  What is start_pos?
  - start_pos is the position of the first element of the current query. That
  is, in LLMs during the generate phase, when we generate one token at a time,
  the query corresponds to a monotonically increasing start_pos: e.g. the first
  token is at start_pos = 0, the second token is at start_pos = 1, and so on.
  If we do prefill with a prompt that has 4 tokens, then during the decode
  phase start_pos = 4.

  Why is start_pos needed?
  - Attention should not need to know start_pos. To apply a causal mask we can
  use the is_causal parameter (which the aten API for SDPA is considering
  removing). However, the current handling of is_causal assumes that
  start_pos = 0. So when we have a query during decode at start_pos = 4, it
  will be a single vector of [1, head_dim] for a given head. The key, derived
  from the kv cache, will be of size [start_pos + 1, head_dim], i.e. all the
  past tokens contained in the kv cache. If we apply the causal mask naively,
  the query is assumed to be at start_pos = 0, and thus all the "future" tokens
  (indices 1...4) in q @ k.T = [1, start_pos + 1] will be masked out of the
  attention calculation. That is not right: since the query is at pos 4, i.e.
  the 4th token, it should attend to all previous tokens in the cache, that is
  0...start_pos. Thus we need to pass start_pos.

  Can we use attn_mask?
  - Yes. An attention mask can be used for the same purpose; however, at the
  moment the attention mask for our llama model is a boolean mask which
  requires conversion to -inf for the masked-out sections. This requires a
  change that may have perf implications, though we haven't really validated
  this. It is possible that there is no perf implication. If the mask were a
  float mask, things would work out-of-the-box. In our llama definition each
  layer stores the mask, and if we move to a float mask that can increase the
  memory footprint, which is right now optimized away since sdpa_with_kv_cache
  does not use attn_mask.

  TODO: Just handle conversion of bool mask to float
*/
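
/*
  Worked example of why start_pos matters (illustrative): suppose a 4-token
  prompt was prefilled and we are decoding the next token, so start_pos = 4
  and the query is a single row. q @ k.T then has shape [1, 5] (the cache
  holds keys 0..4). A naive causal mask anchored at position 0 would mask
  columns 1..4 and leave only key 0 visible; with start_pos passed in, the
  query row is treated as position 4 and attends to all of keys 0..4, which
  is the intended behavior.
*/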
template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
void cpu_flash_attention(
    Tensor& output,
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    double dropout_p,
    bool is_causal,
    const optional<Tensor>& attn_mask,
    const optional<double>& scale,
    bool is_seq_at_dim_1 = false,
    const int64_t start_pos = 0) {
  (void)dropout_p;
  // Query (Batch x Num_heads x Q_seq_len  x Dim_per_head)
  // Key   (Batch x Num_heads x KV_seq_len x Dim_per_head)
  // Value (Batch x Num_heads x KV_seq_len x Dim_per_head)

  /*
  // -> (Batch x Q_seq_len  x Num_heads x Dim_per_head)
  at::Tensor query = q.transpose(1, 2);
  // -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
  at::Tensor key = k.transpose(1, 2);
  // -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
  at::Tensor value = v.transpose(1, 2);
  */

  // Without this we have out-of-bounds writes for
  // causal masking
  static_assert(
      kv_split_size > q_split_size,
      "kv_split_size must be greater than q_split_size");

  constexpr bool is_reduced_type =
      ::executorch::runtime::is_reduced_floating_point_v<scalar_t>;

  ET_CHECK_MSG(
      !is_reduced_type, "FlashAttention does not support reduced types.");
  // Figure out mixed precision a little later
  // using accum_t = at::opmath_type<scalar_t>;
  using accum_t = scalar_t;
  using Vec = vec::Vectorized<accum_t>;
  accum_t scaling_factor =
      static_cast<accum_t>(util::calculate_scale(query, scale));

  int64_t batchSize = query.size(0);
  int64_t num_head = query.size(1);
  int64_t qSize = query.size(2);
  int64_t headSize = query.size(3);
  int64_t kvSize = value.size(2);
  int64_t num_heads_kv = key.size(1);

  if (is_seq_at_dim_1) {
    num_head = query.size(2);
    num_heads_kv = key.size(2);
    qSize = query.size(1);
    kvSize = value.size(1);
  }

  ET_CHECK_MSG(
      num_heads_kv <= num_head,
      "FlashAttention does not support num kv heads > num query heads. Got num query heads=%" PRId64
      " num key heads:%" PRId64,
      num_head,
      num_heads_kv);
  ET_CHECK_MSG(
      num_head % num_heads_kv == 0,
      "FlashAttention: num query heads must be divisible by num kv heads but got num query heads=%" PRId64
      " and num kv heads=%" PRId64,
      num_head,
      num_heads_kv);
  int64_t num_reps = num_head / num_heads_kv;

  bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
  if (has_attn_mask) {
    /*
    TODO: fix this for upcasting attn mask
    if (is_reduced_type) {
      // Should not reach here for now.
      attn_mask.value() = attn_mask.value().to(at::kFloat);
    }
    */
    ET_CHECK_MSG(attn_mask.value().dim() == 2, "attn_mask must be 2D");
    ET_CHECK_MSG(
        attn_mask.value().size(0) == qSize, "attn_mask shape mismatch");
    ET_CHECK_MSG(
        attn_mask.value().size(1) == kvSize,
        "attn_mask shape mismatch: "
        "attn_mask.size(1)=%zd kvSize=%" PRId64,
        attn_mask.value().size(1),
        kvSize);
  }

  auto strides = query.strides();
  int64_t qStrideB = strides[0];
  int64_t qStrideH = strides[1];
  int64_t qStrideM = strides[2];

  if (is_seq_at_dim_1) {
    qStrideH = strides[2];
    qStrideM = strides[1];
  }

  strides = key.strides();
  int64_t kStrideB = strides[0];
  int64_t kStrideH = strides[1];
  int64_t kStrideN = strides[2];

  if (is_seq_at_dim_1) {
    kStrideH = strides[2];
    kStrideN = strides[1];
  }

  strides = value.strides();
  int64_t vStrideB = strides[0];
  int64_t vStrideH = strides[1];
  int64_t vStrideN = strides[2];

  if (is_seq_at_dim_1) {
    vStrideH = strides[2];
    vStrideN = strides[1];
  }

  strides = output.strides();
  int64_t oStrideB = strides[0];
  int64_t oStrideH = strides[1];
  int64_t oStrideM = strides[2];

  if (is_seq_at_dim_1) {
    oStrideH = strides[2];
    oStrideM = strides[1];
  }

  int64_t mStrideB = 0;
  int64_t mStrideH = 0;
  int64_t mStrideM = 0;
  if (has_attn_mask) {
    // int64_t mStrideB = 0;
    //(has_attn_mask && attn_mask.value().size(0) > 1)
    //    ? attn_mask.value().stride(0)
    //    : 0;
    // int64_t mStrideH = 0;
    //(has_attn_mask && attn_mask.value().size(1) > 1)
    //    ? attn_mask.value().stride(1)
    //    : 0;
    strides = attn_mask.value().strides();
    mStrideM = strides[0];
  }

  int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
  int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
  int64_t qSlice = (qSize - 1) / qSplitSize + 1;
#ifdef ET_USE_THREADPOOL
  int64_t num_thread =
      ::executorch::extension::threadpool::get_threadpool()->get_thread_count();
#else
  int64_t num_thread = 1;
#endif

  // const auto dtype = query.scalar_type();
  // Following will be revisited in the future
  // const auto accumulate_dtype = dtype; // toOpMathType(dtype);

  // allocate per thread temp buf (accumulate type)
  int64_t size_per_thread =
      /* qk     */ qSplitSize * kvSplitSize +
      /* qk_max */ qSplitSize +
      /* qk_sum */ qSplitSize +
      /* dst    */ qSplitSize * headSize;
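
  // Each thread's scratch area is carved out of one flat allocation; roughly:
  //   [ qk : qSplitSize * kvSplitSize ][ qk_max : qSplitSize ]
  //   [ qk_sum : qSplitSize ][ dst : qSplitSize * headSize ]
  // The per-thread base pointers computed in compute_lambda below follow this
  // layout.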

  int64_t size_bytes = size_per_thread * num_thread * query.element_size();
  std::vector<char> buf_vec(size_bytes);
  void* buf = reinterpret_cast<void*>(buf_vec.data());
  // Need to double check the following
  size_bytes = num_thread * qSplitSize * kvSplitSize * query.element_size();
  std::vector<char> buf_reduced_vec(size_bytes);
  void* buf_reduced = reinterpret_cast<void*>(buf_reduced_vec.data());
  // at::Tensor buf_reduced = at::empty(
  //    {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
  //    query.options());

  // Data ptrs
  const scalar_t* q_data = query.const_data_ptr<scalar_t>();
  const scalar_t* k_data = key.const_data_ptr<scalar_t>();
  const scalar_t* v_data = value.const_data_ptr<scalar_t>();
  const accum_t* mask_data =
      has_attn_mask ? attn_mask.value().const_data_ptr<accum_t>() : nullptr;
  scalar_t* out_data = output.mutable_data_ptr<scalar_t>();
  accum_t* buf_data = reinterpret_cast<accum_t*>(buf);
  scalar_t* buf_reduced_data =
      is_reduced_type ? reinterpret_cast<scalar_t*>(buf_reduced) : nullptr;

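  // compute_lambda processes a contiguous range of flattened work items; each
  // item z maps (via data_index_init/step) to a (batch, head, q-block) triple
  // (i, j, k), so parallel_for below distributes batchSize * num_head * qSlice
  // independent q-blocks across threads.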
  auto compute_lambda = [&](int64_t begin, int64_t end) {
    int64_t i = 0, j = 0, k = 0;
    util::data_index_init(begin, i, batchSize, j, num_head, k, qSlice);
    int ompIdx = torch::executor::get_thread_num();
    accum_t* buf_ptr = buf_data + ompIdx * size_per_thread;
    accum_t* qk_data = buf_ptr;
    accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize;
    accum_t* qk_sum_data = qk_max_data + qSplitSize;
    accum_t* dst_data = qk_sum_data + qSplitSize;
    scalar_t* qk_reduced_data = is_reduced_type
        ? buf_reduced_data + ompIdx * qSplitSize * kvSplitSize
        : nullptr;

    for (int64_t z = begin; z < end; z++) {
      int64_t m = k * qSplitSize;
      int64_t qBlockSize = std::min(qSplitSize, qSize - m);
      // Initialize max and sum
      fill_stub(
          qk_max_data, -std::numeric_limits<accum_t>::infinity(), qBlockSize);
      // The original flash sdpa wasn't really meant to be used
      // for decode the way we are using it via start_pos here.
      // Thus when num_keys is 1 during the decode phase, we
      // still need to iterate through all the kv_splits.
      // Take start_pos = 130 and k_split_size = 128.
      // Here we have to produce [1x130] of q @ k.T
      // when seq_len = 1.
      // But if num_keys = 1 then we don't really loop over
      // all kv_splits.
      // When k_split_size > 130, this is not an issue because
      // there is only one iteration of the following loop anyway.
      // Outside of determining how many loop iterations are needed,
      // num_keys participates only in causal attention.
      // The rest of the calculation of q @ k.T and @ v.T is the same.
      // We don't run into this bug when k_split_size < start_pos + seqlen
      // since there is only one iteration and that applies
      // causal attention correctly.
      // However, when k_split_size > start_pos + seqlen, we have
      // more than one iteration, yet if we don't adjust num_keys
      // we don't get more than one iteration.
      // This is unique to this deployment of flash attention since
      // the original implementation wasn't deployed this way.

      // Some of these bugs can be resolved by relying on the attention mask,
      // but that requires storing the attention mask in float as the current
      // code doesn't support a bool attention mask.
      // However, let's just fix that as well.
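      // For example (illustrative): during decode with start_pos = 4,
      // qBlockSize = 1 and m = 0, the causal branch below gives
      // num_keys = min(0 + 4 + 1, kvSize) = 5, so the single query row scans
      // keys 0..4 rather than stopping after the first kv split.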
      int64_t num_keys =
          is_causal ? std::min(m + start_pos + qBlockSize, kvSize) : kvSize;
      auto j_kv = j / num_reps;
      for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
        int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
        // Calculate scale * q @ k.T
        fill_stub(qk_data, static_cast<accum_t>(0), qSplitSize * kvSplitSize);
        ::executorch::cpublas::gemm(
            ::executorch::cpublas::TransposeType::Transpose,
            ::executorch::cpublas::TransposeType::NoTranspose,
            kvBlockSize,
            qBlockSize,
            headSize,
            static_cast<accum_t>(1),
            k_data + i * kStrideB + j_kv * kStrideH + n * kStrideN,
            kStrideN,
            q_data + i * qStrideB + j * qStrideH + m * qStrideM,
            qStrideM,
            static_cast<accum_t>(0),
            qk_data,
            kvBlockSize);
        // Apply causal mask: fill unused, i.e. future, values with -inf.
        // Say you have q @ k.T of size [16, 32].
        // With q block size = 4, say you are processing
        // q seq len dim = 8:11.
        // Say kvSplitSize = 4.
        // Then for the causal mask, the entries that need to be
        // ignored are
        // [8, 9:31], [9, 10:31], [10, 11:31], [11, 12:31]
        // Here num_keys = 8 + 4 = 12, and the condition
        //   (num_keys - n) <= kvSplitSize, i.e. num_keys <= n + kvSplitSize,
        // detects the kv blocks that need masking: if n + kvSplitSize
        // reaches 12, some entries must be masked out. In our example n = 8
        // qualifies for that.
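        // With a nonzero start_pos the same bookkeeping shifts by start_pos:
        // e.g. during decode with qBlockSize = 1, m = 0, start_pos = 4 and
        // n = 0, last_col below is 0 + (0 + 4) - 0 = 4, so columns
        // 5..kvBlockSize-1 of the single row are filled with -inf.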
        if (is_causal && num_keys - n <= kvSplitSize) {
          // For this to work, kv_split_size must be > q_split_size
          for (int32_t row = 0; row < qBlockSize; ++row) {
            int64_t last_col = m + (row + start_pos) - n;
            accum_t* row_ptr = qk_data + row * kvBlockSize;
            fill_stub(
                row_ptr + last_col + 1,
                -std::numeric_limits<accum_t>::infinity(),
                kvBlockSize - last_col - 1);
          }
        }
        // Update attention weights with attention mask
        // and apply scaling factor
        // qk <- qk * scaling + attn_mask
        if (has_attn_mask) {
          for (int64_t row = 0; row < qBlockSize; ++row) {
            vec::map2<accum_t>(
                [scaling_factor](Vec x, Vec y) {
                  return x * Vec(scaling_factor) + y;
                },
                qk_data + row * kvBlockSize,
                qk_data + row * kvBlockSize,
                mask_data + i * mStrideB + j * mStrideH + (m + row) * mStrideM +
                    n,
                kvBlockSize);
          }
        }
        // Update coefficients with Softmax
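        // This is the streaming (online) softmax update used by flash
        // attention. Per row, with m_old/l_old the running max and sum and
        // S the current scaled qk block, the code below computes roughly:
        //   m_new = max(m_old, rowmax(S))
        //   p     = exp(S - m_new)
        //   l_new = exp(m_old - m_new) * l_old + rowsum(p)
        //   acc   = exp(m_old - m_new) * acc   (the gemm afterwards adds p @ V)
        // and the final normalization by 1 / l happens after the kv loop.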
        accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0;
        for (int64_t row = 0; row < qBlockSize; ++row) {
          if (has_attn_mask) {
            // max per row
            tmp_max = vec::reduce_all<accum_t>(
                [](Vec& x, Vec& y) { return vec::maximum(x, y); },
                qk_data + row * kvBlockSize,
                kvBlockSize);
          } else {
            // apply scaling factor and max per row in fusion
            _mul_reduce_max_fusion_kernel(
                qk_data + row * kvBlockSize,
                scaling_factor,
                kvBlockSize,
                qk_data + row * kvBlockSize,
                tmp_max);
          }
          tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
          // qk <- exp(qk - max) and sum per row
          tmp_sum = tmp_max;
          _exp_reduce_sum_fusion_kernel(
              qk_data + row * kvBlockSize,
              kvBlockSize,
              conditional_data_ptr(qk_data, qk_reduced_data) +
                  row * kvBlockSize,
              tmp_sum);
          // exp_tmp <- exp(max[row] - max)
          exp_tmp = std::exp(qk_max_data[row] - tmp_max);
          // sum[row] <- sum + exp_tmp * sum[row]
          qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
          // max[row] <- max
          qk_max_data[row] = tmp_max;
          // dst <- dst * exp_tmp
          if (n > 0) {
            vec::map<accum_t>(
                [exp_tmp](Vec x) { return x * Vec(exp_tmp); },
                dst_data + row * headSize,
                dst_data + row * headSize,
                headSize);
          }
        }
        // Calculate Softmax(q @ k.T) @ v
        ::executorch::cpublas::gemm(
            ::executorch::cpublas::TransposeType::NoTranspose,
            ::executorch::cpublas::TransposeType::NoTranspose,
            headSize,
            qBlockSize,
            kvBlockSize,
            static_cast<accum_t>(1),
            v_data + i * vStrideB + j_kv * vStrideH + n * vStrideN,
            vStrideN,
            conditional_data_ptr(qk_data, qk_reduced_data),
            kvBlockSize,
            n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1),
            dst_data,
            headSize);
      }
      // dst <- dst / sum[row]
      // reorder MHA output with strides
      for (int64_t row = 0; row < qBlockSize; ++row) {
        accum_t sum_reciprocal = 1 / qk_sum_data[row];
        vec::map<scalar_t>(
            [sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
            out_data + i * oStrideB + j * oStrideH + m * oStrideM +
                row * oStrideM,
            dst_data + row * headSize,
            headSize);
      }
      // Move to the next query
      util::data_index_step(i, batchSize, j, num_head, k, qSlice);
    }
  };
  torch::executor::parallel_for(
      0, batchSize * num_head * qSlice, 1, compute_lambda);
}

bool validate_flash_attention_args(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const optional<Tensor>& attn_mask) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(query.dim() == 4, "query must be a 4D tensor");
  ET_LOG_MSG_AND_RETURN_IF_FALSE(key.dim() == 4, "key must be a 4D tensor");
  ET_LOG_MSG_AND_RETURN_IF_FALSE(value.dim() == 4, "value must be a 4D tensor");

  // Sizes
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
      "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      (query.scalar_type() == ScalarType::Float), "Query must be Float type");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      (query.scalar_type() == key.scalar_type()) &&
          (query.scalar_type() == value.scalar_type()),
      "Key and Value must have the same data type as Query");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      !attn_mask.has_value() || attn_mask.value().dim() == 2,
      "Attention mask must be a 2D tensor");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      !attn_mask.has_value() ||
          attn_mask.value().scalar_type() == query.scalar_type(),
      "Attention mask must have the same data type as the query");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      is_contiguous_dim_order(query.dim_order().data(), query.dim()),
      "query must be in contiguous dim order");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      is_contiguous_dim_order(key.dim_order().data(), key.dim()),
      "key must be in contiguous dim order");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      is_contiguous_dim_order(value.dim_order().data(), value.dim()),
      "value must be in contiguous dim order");

  if (attn_mask.has_value()) {
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        is_contiguous_dim_order(
            attn_mask.value().dim_order().data(), attn_mask.value().dim()),
        "attn_mask must be in contiguous dim order");
  }

  return true;
}

bool validate_cache_params(
    const Tensor& k_cache,
    const Tensor& v_cache,
    int64_t start_pos,
    int64_t seq_length) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      k_cache.dim() == 4, "k_cache must be a 4D tensor");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      v_cache.dim() == 4, "v_cache must be a 4D tensor");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      start_pos < k_cache.size(1),
      "start_pos must be less than key cache size at dim 1");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      start_pos < v_cache.size(1),
      "start_pos must be less than value cache size at dim 1");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      (start_pos + seq_length) <= k_cache.size(1),
      "start_pos + seq_length must not exceed the max seq length supported by the key cache. "
      "start pos: %" PRId64 ", seq_length: %" PRId64
      ". "
      "key cache size: %zd",
      start_pos,
      seq_length,
      k_cache.size(1));

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      (start_pos + seq_length) <= v_cache.size(1),
      "start_pos + seq_length must not exceed the max seq length supported by the value cache. "
      "start pos: %" PRId64 ", seq_length: %" PRId64
      ". "
      "value cache size: %zd",
      start_pos,
      seq_length,
      v_cache.size(1));

  // Make sure they are in contiguous dim order
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      is_contiguous_dim_order(k_cache.dim_order().data(), k_cache.dim()),
      "key cache must be in contiguous dim order");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      is_contiguous_dim_order(v_cache.dim_order().data(), v_cache.dim()),
      "value cache must be in contiguous dim order");

  return true;
}

// TODO: seq_length is not yet used for copy
void update_cache(
    const Tensor& projected_value,
    const Tensor& cache,
    int64_t start_pos,
    int64_t seq_length) { // NOLINT: unused parameter 'seq_length'
  // 1) Cache shape should be [bs, max_seq_len, num heads, head dim]
  // 2) projected_value shape should be [bs, seq_len, num heads, head dim]
  // 3) We're updating the cache with projected_value, at position start_pos
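  //
  // In effect (ignoring seq_length, which is not used yet), this performs
  //   cache[b, start_pos : start_pos + seq_len, :, :] = projected_value[b]
  // for every batch index b, via a single contiguous memcpy per batch, which
  // is why both tensors are required to be in contiguous dim order.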

  ET_CHECK_MSG(
      projected_value.size(0) == cache.size(0),
      "projected_value batch size should be equal to the cache batch size.");
  ET_CHECK_MSG(
      projected_value.size(2) == cache.size(2),
      "projected_value number of heads should be equal to the cache number of heads.");
  ET_CHECK_MSG(
      projected_value.size(3) == cache.size(3),
      "projected_value embedding dimension should be equal to the cache embedding dimension.");
  ET_CHECK_MSG(
      projected_value.element_size() == cache.element_size(),
      "projected_value data type size should be equal to the cache data type size.");

  ET_CHECK_MSG(
      is_contiguous_dim_order(
          projected_value.dim_order().data(), projected_value.dim()),
      "projected value must be in contiguous dim order");
  const void* projected_value_data = projected_value.const_data_ptr();
  void* cache_data = cache.mutable_data_ptr();

  ET_CHECK_MSG(projected_value_data != nullptr, "projected_value data is null");
  ET_CHECK_MSG(cache_data, "cache data is null");

  auto cache_strides = cache.strides();
  exec_aten::StridesType cache_batch_dim_stride = cache_strides[0];
  exec_aten::StridesType cache_seq_dim_stride = cache_strides[1];

  auto value_strides = projected_value.strides();
  exec_aten::StridesType value_batch_dim_stride = value_strides[0];

  exec_aten::SizesType num_bytes_to_copy =
      (projected_value.numel() / projected_value.size(0)) *
      projected_value.element_size();

  for (int64_t batch_line = 0; batch_line < projected_value.size(0);
       ++batch_line) {
    exec_aten::SizesType cache_pos_offset =
        (batch_line * cache_batch_dim_stride +
         start_pos * cache_seq_dim_stride) *
        cache.element_size();
    exec_aten::SizesType value_pos_offset =
        (batch_line * value_batch_dim_stride) * cache.element_size();

    std::memcpy(
        (uint8_t*)cache_data + cache_pos_offset,
        (uint8_t*)projected_value_data + value_pos_offset,
        num_bytes_to_copy);
  }
}

} // anonymous namespace

Tensor& flash_attention_kernel_out(
    RuntimeContext& ctx,
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const optional<Tensor>& attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<double> scale,
    Tensor& output) {
  (void)ctx;
  ET_KERNEL_CHECK(
      ctx,
      validate_flash_attention_args(query, key, value, attn_mask),
      InvalidArgument,
      output);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(output, query.sizes()) == Error::Ok,
      InvalidArgument,
      output);

  auto q_seq_len = query.size(2);

  ET_SWITCH_FLOAT_TYPES(
      query.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
        // TODO we need to re-evaluate this for ARM CPUs.
        // And there can be many, so instead of templatizing
        // we might consider another approach.
        if (q_seq_len >= 768) {
          cpu_flash_attention<CTYPE, 256, 512>(
              output,
              query,
              key,
              value,
              dropout_p,
              is_causal,
              attn_mask,
              scale);
        } else if (q_seq_len >= 192) {
          cpu_flash_attention<CTYPE, 64, 512>(
              output,
              query,
              key,
              value,
              dropout_p,
              is_causal,
              attn_mask,
              scale);
        } else {
          cpu_flash_attention<CTYPE, 32, 512>(
              output,
              query,
              key,
              value,
              dropout_p,
              is_causal,
              attn_mask,
              scale);
        }
      });
  return output;
}

/*
  Input params
  @param[in] q Projected query with query weights.
             Format [batch size, seq_len, num heads, head dim]
  @param[in] k Key cache holding projected keys for positions
             [0, start_pos + seq_len).
             Format [batch size, max_seq_len, num heads, head dim]
  @param[in] v Value cache holding projected values for positions
             [0, start_pos + seq_len).
             Format [batch size, max_seq_len, num heads, head dim]
  ....
  @param[in] start_pos: sequence position of the first query token; the key
             and value caches are sliced to [0, start_pos + seq_len) before
             attention is computed.
*/
Tensor& custom_sdpa_out(
    RuntimeContext& ctx,
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    const int64_t start_pos,
    const optional<Tensor>& attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<double> scale,
    Tensor& output) {
  ET_KERNEL_CHECK_MSG(
      ctx,
      !attn_mask.has_value() || !is_causal,
      InvalidArgument,
      output,
      "attn_mask and is_causal cannot be set at the same time");

  ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");

  const int64_t seq_len = q.size(1);
  auto q_seq_len = q.size(1);

  // Refactor the following into create_view util perhaps using
  // TensorPtr
  std::array<exec_aten::DimOrderType, util::kKVDim> sliced_key_dim_order{
      0, 1, 2, 3};
  std::array<exec_aten::SizesType, util::kKVDim> sliced_key_sizes;
  sliced_key_sizes[0] = k.size(0);
  sliced_key_sizes[1] = start_pos + seq_len; // key_cache.size(2);
  sliced_key_sizes[2] = k.size(2);
  sliced_key_sizes[3] = k.size(3);
  std::array<exec_aten::StridesType, util::kKVDim> sliced_key_strides;
  dim_order_to_stride_nocheck(
      sliced_key_sizes.data(),
      sliced_key_dim_order.data(),
      util::kKVDim,
      sliced_key_strides.data());
  // since the cache is sliced, the batch stride needs to stay the same.
  sliced_key_strides[0] = k.strides()[0];
  void* key_cache_data = k.mutable_data_ptr();
  TensorImpl k_impl = TensorImpl(
      k.scalar_type(),
      util::kKVDim,
      sliced_key_sizes.data(),
      key_cache_data,
      sliced_key_dim_order.data(),
      sliced_key_strides.data(),
      TensorShapeDynamism::STATIC);
  Tensor sliced_key_cache(&k_impl);

  std::array<exec_aten::DimOrderType, util::kKVDim> sliced_value_dim_order{
      0, 1, 2, 3};
  std::array<exec_aten::SizesType, util::kKVDim> sliced_value_sizes;
  sliced_value_sizes[0] = v.size(0);
  sliced_value_sizes[1] = start_pos + seq_len; // value_cache.size(2);
  sliced_value_sizes[2] = v.size(2);
  sliced_value_sizes[3] = v.size(3);
  std::array<exec_aten::StridesType, util::kKVDim> sliced_value_strides;
  dim_order_to_stride_nocheck(
      sliced_value_sizes.data(),
      sliced_value_dim_order.data(),
      util::kKVDim,
      sliced_value_strides.data());
  // since the cache is sliced, the batch stride needs to stay the same.
  sliced_value_strides[0] = v.strides()[0];
  void* value_cache_data = v.mutable_data_ptr();
  TensorImpl value_impl = TensorImpl(
      v.scalar_type(),
      util::kKVDim,
      sliced_value_sizes.data(),
      value_cache_data,
      sliced_value_dim_order.data(),
      sliced_value_strides.data(),
      TensorShapeDynamism::STATIC);
  Tensor sliced_value_cache(&value_impl);
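
  // At this point sliced_key_cache / sliced_value_cache are strided views of
  // the caller-provided k / v buffers covering positions
  // [0, start_pos + seq_len) along the sequence dimension; no data is copied,
  // only the sizes and strides differ from the originals.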

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(output, q.sizes()) == Error::Ok,
      InvalidArgument,
      output);

  // TODO(task): replace the template param selection logic
  // with whatever makes more sense.
  ET_SWITCH_FLOAT_TYPES(q.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
    // TODO we need to re-evaluate this for ARM CPUs.
    // And there can be many, so instead of templatizing
    // we might consider another approach.
    if (q_seq_len >= 768) {
      cpu_flash_attention<CTYPE, 256, 512>(
          output,
          q,
          sliced_key_cache,
          sliced_value_cache,
          dropout_p,
          is_causal,
          attn_mask,
          scale,
          true, /* is_seq_at_dim_1 */
          start_pos);
    } else if (q_seq_len >= 192) {
      cpu_flash_attention<CTYPE, 64, 512>(
          output,
          q,
          sliced_key_cache,
          sliced_value_cache,
          dropout_p,
          is_causal,
          attn_mask,
          scale,
          true, /* is_seq_at_dim_1 */
          start_pos);
    } else {
      cpu_flash_attention<CTYPE, 32, 512>(
          output,
          q,
          sliced_key_cache,
          sliced_value_cache,
          dropout_p,
          is_causal,
          attn_mask,
          scale,
          true, /* is_seq_at_dim_1 */
          start_pos);
    }
  });
  return output;
}

/*
  Input params
  @param[in] q_projected Input projected with query weights.
             Format [batch size, seq_len, num heads, head dim]
  @param[in] k_projected Input projected with key weights.
             Format [batch size, seq_len, num heads, head dim]
  @param[in] v_projected Input projected with value weights.
             Format [batch size, seq_len, num heads, head dim]
  @param[in] key_cache Cache of previous k_projected.
             Format [batch size, max_seq_len, num heads, head dim]
  @param[in] value_cache Cache of previous v_projected.
             Format [batch size, max_seq_len, num heads, head dim]
  ....
  @param[in] start_pos: sequence position
  @param[in] seq_len: Seq length, e.g. the seq_len dim of q_projected.
*/
Tensor& sdpa_with_kv_cache_out(
    KernelRuntimeContext& ctx,
    const Tensor& q_projected,
    const Tensor& k_projected,
    const Tensor& v_projected,
    Tensor& key_cache,
    Tensor& value_cache,
    const int64_t start_pos,
    const int64_t seq_len,
    const optional<Tensor>& attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<double> scale,
    Tensor& output) {
  (void)ctx;
  ET_KERNEL_CHECK(
      ctx,
      validate_cache_params(key_cache, value_cache, start_pos, seq_len),
      InvalidArgument,
      output);

  ET_CHECK_MSG(q_projected.dim() == 4, "query must be a 4D tensor");

  update_cache(k_projected, key_cache, start_pos, seq_len);
  update_cache(v_projected, value_cache, start_pos, seq_len);

  custom_sdpa_out(
      ctx,
      q_projected,
      key_cache,
      value_cache,
      start_pos,
      attn_mask,
      dropout_p,
      is_causal,
      scale,
      output);

  return output;
}
} // namespace native
} // namespace executor
} // namespace torch

EXECUTORCH_LIBRARY(
    llama,
    "sdpa_with_kv_cache.out",
    torch::executor::native::sdpa_with_kv_cache_out);

EXECUTORCH_LIBRARY(
    llama,
    "custom_sdpa.out",
    torch::executor::native::custom_sdpa_out);