1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8 #pragma once
9
10 #include <ATen/cuda/PhiloxUtils.cuh>
11 #include <c10/util/Exception.h>
12
13 #include <curand_kernel.h>
14 #include <cmath>
15 #include <vector>
16
17 #include <cutlass/bfloat16.h>
18 #include <cutlass/fast_math.h>
19 #include <cutlass/gemm/gemm.h>
20 #include <cutlass/layout/matrix.h>
21 #include <cutlass/layout/vector.h>
22 #include <cutlass/matrix.h>
23 #include <cutlass/numeric_types.h>
24 #include <cutlass/tensor_ref.h>
25
26 #include <cutlass/epilogue/threadblock/default_epilogue_simt.h>
27 #include <cutlass/epilogue/threadblock/default_epilogue_tensor_op.h>
28 #include <cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h>
29
30 #include <cutlass/gemm/device/default_gemm_configuration.h>
31 #include <cutlass/gemm/kernel/default_gemm.h>
32 #include <cutlass/gemm/threadblock/default_mma.h>
33 #include <cutlass/gemm/threadblock/default_mma_core_simt.h>
34 #include <cutlass/gemm/threadblock/default_mma_core_sm70.h>
35 #include <cutlass/gemm/threadblock/default_mma_core_sm75.h>
36 #include <cutlass/gemm/threadblock/default_mma_core_sm80.h>
37 #include <cutlass/gemm/threadblock/threadblock_swizzle.h>
38 #include <cutlass/matrix_shape.h>
39 #include <cutlass/platform/platform.h>
40 #include <cutlass/transform/threadblock/predicated_tile_iterator.h>
41
42 #include <ATen/native/transformers/cuda/mem_eff_attention/debug_utils.h>
43 #include <ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h>
44 #include <ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h>
45
46 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma.h>
47 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm/find_default_mma.h>
48 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h>
49 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h>
50 #include <ATen/native/transformers/cuda/mem_eff_attention/transform/tile_smem_loader.h>
51
52 #include <cinttypes>
53
54 using namespace gemm_kernel_utils;
55
56 namespace PyTorchMemEffAttention {
57 namespace {
58 template <typename scalar_t, typename Arch>
59 constexpr int getWarpsPerSmFw() {
60 return (
61 Arch::kMinComputeCapability >= 80 &&
62 !cutlass::platform::is_same<scalar_t, float>::value
63 ? 16
64 : 12);
65 }
66 static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
67 // source: https://stackoverflow.com/a/51549250
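// The IEEE-754 bit pattern of a non-negative float orders like a signed
// int, so atomicMax on the reinterpreted bits yields the float max. For a
// negative `value` the unsigned bit order of floats is reversed (and any
// stored non-negative float compares smaller), so atomicMin on the bits
// reinterpreted as unsigned again returns the float max.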
68 return !signbit(value)
69 ? __int_as_float(atomicMax((int *)addr, __float_as_int(value)))
70 : __uint_as_float(
71 atomicMin((unsigned int *)addr, __float_as_uint(value)));
72 }
73 } // namespace
74
75 template <
76 // The datatype of Q/K/V
77 typename scalar_t_,
78 // Architecture we are targeting (eg `cutlass::arch::Sm80`)
79 typename ArchTag,
80 // If Q/K/V are correctly aligned in memory and we can run a fast kernel
81 bool isAligned_,
82 int kQueriesPerBlock_,
83 int kKeysPerBlock_,
84 // upper bound on `max(value.shape[-1], query.shape[-1])`
85 int kMaxK_ = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
86 // For some reason this is quite a bit slower on V100
87 // Set to false if you know at compile-time you will never need dropout
88 bool kSupportsDropout_ = true,
89 bool kSupportsBias_ = true>
90 struct AttentionKernel {
91 enum CustomMaskType {
92 NoCustomMask = 0,
93 CausalFromTopLeft = 1,
94 CausalFromBottomRight = 2,
95 NumCustomMaskTypes,
96 };
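// Illustrative example: with num_queries = 2 and num_keys = 4,
// CausalFromTopLeft lets query rows 0 and 1 attend to keys {0} and {0, 1},
// while CausalFromBottomRight aligns the diagonal with the last key
// (causal_diagonal_offset = num_keys - num_queries = 2), so the same rows
// attend to {0, 1, 2} and {0, 1, 2, 3}.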
97
98 using scalar_t = scalar_t_;
99 using accum_t = float;
100 using lse_scalar_t = float;
101 using output_t = scalar_t;
102 // Accumulator between 2 iterations
103 // Using `accum_t` improves perf on f16 at the cost of
104 // numerical errors
105 using output_accum_t = accum_t;
106 static constexpr bool kSupportsDropout = kSupportsDropout_;
107 static constexpr bool kSupportsBias = kSupportsBias_;
108 static constexpr int kKeysPerBlock = kKeysPerBlock_;
109 static constexpr int kQueriesPerBlock = kQueriesPerBlock_;
110 static constexpr int kMaxK = kMaxK_;
111 static constexpr bool kIsAligned = isAligned_;
112 static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock;
113 static constexpr int32_t kAlignLSE = 32; // block size of backward
114 static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value == 16;
115 static constexpr bool kPreloadV =
116 ArchTag::kMinComputeCapability >= 80 && kIsHalf;
117 static constexpr bool kKeepOutputInRF = kSingleValueIteration;
118 static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
119 !cutlass::platform::is_same<output_accum_t, output_t>::value;
120
121 static_assert(kQueriesPerBlock % 32 == 0, "");
122 static_assert(kKeysPerBlock % 32 == 0, "");
123 static constexpr int kNumWarpsPerBlock =
124 kQueriesPerBlock * kKeysPerBlock / (32 * 32);
125 static constexpr int kWarpSize = 32;
126
127 // Launch bounds
128 static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
129 static constexpr int kMinBlocksPerSm =
130 getWarpsPerSmFw<scalar_t, ArchTag>() / kNumWarpsPerBlock;
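// Worked example (illustrative, not an instantiation defined here): for a
// half-precision Sm80 build with kQueriesPerBlock = kKeysPerBlock = 64,
//   kNumWarpsPerBlock = 64 * 64 / (32 * 32) = 4
//   kNumThreads       = 32 * 4 = 128
//   kMinBlocksPerSm   = getWarpsPerSmFw<half, Sm80>() / 4 = 16 / 4 = 4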
131
132 struct Params {
133 // Input tensors
134 const scalar_t* query_ptr = nullptr; // [num_queries, num_heads, head_dim]
135 const scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim]
136 const scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value]
137 const scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys]
138 const int32_t* seqstart_q_ptr = nullptr;
139 const int32_t* seqstart_k_ptr = nullptr;
140
141 const int32_t* seqlen_k_ptr = nullptr;
142 uint32_t causal_diagonal_offset = 0;
143
144 // Output tensors
145 output_t* output_ptr = nullptr; // [num_queries, num_heads, head_dim_value]
146 // [num_queries, num_heads, head_dim_value]
147 output_accum_t* output_accum_ptr = nullptr;
148 // [num_heads, num_queries] - can be null
149 lse_scalar_t* logsumexp_ptr = nullptr;
150
151 // Sliding window. ignored if == 0
152 int32_t window_size = 0;
153
154 // Scale
155 accum_t scale = 0.0;
156
157 // Dimensions/strides
158 int32_t head_dim = 0;
159 int32_t head_dim_value = 0;
160 int32_t num_queries = 0;
161 int32_t num_keys = 0;
162 int32_t num_keys_absolute = 0;
163
164 uint8_t custom_mask_type = NoCustomMask;
165
166 int32_t q_strideM = 0;
167 int32_t k_strideM = 0;
168 int32_t v_strideM = 0;
169 int32_t bias_strideM = 0;
170
171 int32_t o_strideM = 0;
172
173 // Everything below is only used in `advance_to_block`
174 // and shouldn't use registers
175 int32_t q_strideH = 0;
176 int32_t k_strideH = 0;
177 int32_t v_strideH = 0;
178 int64_t bias_strideH = 0;
179
180 int64_t q_strideB = 0;
181 int64_t k_strideB = 0;
182 int64_t v_strideB = 0;
183 int64_t bias_strideB = 0;
184
185 int32_t num_batches = 0;
186 int32_t num_heads = 0;
187
188 // dropout
189 bool use_dropout = false;
190 unsigned long long dropout_batch_head_rng_offset = 0;
191 float dropout_prob = 0.0f;
192 at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0);
193 int64_t* extragraph_offset = nullptr;
194 int64_t* seed = nullptr;
195
196 // Moves pointers to what we should process
197 // Returns "false" if there is no work to do
198 CUTLASS_DEVICE bool advance_to_block() {
199 auto batch_id = blockIdx.z;
200 auto head_id = blockIdx.y;
201 auto query_start = blockIdx.x * kQueriesPerBlock;
202
203 auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;
204
205 if (kSupportsDropout) {
206 dropout_batch_head_rng_offset =
207 batch_id * num_heads * num_queries * num_keys +
208 head_id * num_queries * num_keys;
209 }
210
211 int64_t q_start = 0, k_start = 0;
212 // Advance to current batch - in case of different sequence lengths
213 if (seqstart_q_ptr != nullptr) {
214 assert(seqstart_k_ptr != nullptr);
215 seqstart_q_ptr += batch_id;
216
217 q_start = seqstart_q_ptr[0];
218 int64_t q_next_start = seqstart_q_ptr[1];
219 int64_t k_end;
220 seqstart_k_ptr += batch_id;
221
222 if (seqlen_k_ptr) {
223 k_start = seqstart_k_ptr[0];
224 k_end = k_start + seqlen_k_ptr[batch_id];
225 } else {
226 k_start = seqstart_k_ptr[0];
227 k_end = seqstart_k_ptr[1];
228 }
229
230 num_queries = q_next_start - q_start;
231 num_keys = k_end - k_start;
232
233 if (query_start >= num_queries) {
234 return false;
235 }
236 } else {
237 query_ptr += batch_id * q_strideB;
238 key_ptr += batch_id * k_strideB;
239 value_ptr += batch_id * v_strideB;
240 output_ptr += int64_t(batch_id * num_queries) * o_strideM;
241 if (output_accum_ptr != nullptr) {
242 output_accum_ptr +=
243 int64_t(batch_id * num_queries) * (head_dim_value * num_heads);
244 }
245 q_start = 0;
246 k_start = 0;
247 }
248
249 // Advance to the current batch / head / query_start
250 query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH;
251 key_ptr += k_start * k_strideM + head_id * k_strideH;
252
253 value_ptr += k_start * v_strideM + head_id * v_strideH;
254 output_ptr +=
255 int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value;
256
257 if (kSupportsBias && attn_bias_ptr != nullptr) {
258 attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH);
259 }
260 if (output_accum_ptr != nullptr) {
261 output_accum_ptr +=
262 int64_t(q_start + query_start) * (head_dim_value * num_heads) +
263 head_id * head_dim_value;
264 } else {
265 // Accumulate directly in the destination buffer (eg for f32)
266 output_accum_ptr = (accum_t*)output_ptr;
267 }
268
269 if (logsumexp_ptr != nullptr) {
270 // lse[batch_id, head_id, query_start]
271 logsumexp_ptr +=
272 batch_id * lse_dim * num_heads + head_id * lse_dim + query_start;
273 }
274
275 // Custom masking
276 if (custom_mask_type == CausalFromBottomRight) {
277 causal_diagonal_offset = num_keys - num_queries;
278 }
279 // We use num_keys_absolute to index into the rng_state
280 // We need this index to match between forward and backwards
281 num_keys_absolute = num_keys;
282 if (custom_mask_type == CausalFromTopLeft ||
283 custom_mask_type == CausalFromBottomRight) {
284 // The bottom row of the current block is query_start + kQueriesPerBlock,
285 // so the last active key is query_start + causal_diagonal_offset +
286 // kQueriesPerBlock. num_keys is therefore the min between the actual
287 // num_keys and this value, to avoid extra computations
288 num_keys = cutlass::fast_min(
289 int32_t(query_start + causal_diagonal_offset + kQueriesPerBlock),
290 num_keys);
291 }
292
293 num_queries -= query_start;
294 num_batches = 0; // no longer used after
295
296 // If num_queries == 1 and there is only one key head, we're wasting
297 // 15/16th of the tensor core compute. In that case:
298 // - we only launch kernels for head_id % kQueriesPerBlock == 0
299 // - we iterate over heads instead of queries (strideM = strideH)
300 if (num_queries == 1 && k_strideH == 0 && v_strideH == 0 &&
301 logsumexp_ptr == nullptr && window_size == 0) {
302 if (head_id % kQueriesPerBlock != 0) {
303 return false;
304 }
305 q_strideM = q_strideH;
306 bias_strideM = bias_strideH;
307 num_queries = num_heads;
308 num_heads = 1; // unused but here for intent
309 // remove causal since n_query = 1;
310 // otherwise, the offset would change with the head!
311 custom_mask_type = NoCustomMask;
312 o_strideM = head_dim_value;
313 }
314
315 // Make sure the compiler knows these variables are the same on all
316 // the threads of the warp.
317 // Only worth doing if they could have been modified above.
318 query_ptr = warp_uniform(query_ptr);
319 key_ptr = warp_uniform(key_ptr);
320 value_ptr = warp_uniform(value_ptr);
321 if (kSupportsBias) {
322 attn_bias_ptr = warp_uniform(attn_bias_ptr);
323 }
324 output_ptr = warp_uniform(output_ptr);
325 output_accum_ptr = warp_uniform(output_accum_ptr);
326 logsumexp_ptr = warp_uniform(logsumexp_ptr);
327 num_queries = warp_uniform(num_queries);
328 num_keys = warp_uniform(num_keys);
329 num_heads = warp_uniform(num_heads);
330 o_strideM = warp_uniform(o_strideM);
331 custom_mask_type = warp_uniform(custom_mask_type);
332 return true;
333 }
334
335 __host__ dim3 getBlocksGrid() const {
336 return dim3(
337 ceil_div(num_queries, (int32_t)kQueriesPerBlock),
338 num_heads,
339 num_batches);
340 }
341
342 __host__ dim3 getThreadsGrid() const {
343 return dim3(kWarpSize, kNumWarpsPerBlock, 1);
344 }
345 };
346
347 struct MM0 {
348 /*
349 In this first matmul, we compute a block of `Q @ K.T`.
350 While the calculation result is still hot in registers, we update
351 `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value
352 into a shared-memory buffer ("AccumulatorSharedStorage") that is used
353 later as operand A for the second matmul (see MM1).
354 */
355 using GemmType = DefaultGemmType<ArchTag, scalar_t>;
356
357 using OpClass = typename GemmType::OpClass;
358 using DefaultConfig =
359 typename cutlass::gemm::device::DefaultGemmConfiguration<
360 OpClass,
361 ArchTag,
362 scalar_t,
363 scalar_t,
364 scalar_t, // ElementC
365 accum_t // ElementAccumulator
366 >;
367 static constexpr int kAlignmentA =
368 kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment;
369 static constexpr int kAlignmentB =
370 kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
371 using ThreadblockShape = cutlass::gemm::
372 GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
373 using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
374 using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma<
375 scalar_t, // ElementA,
376 cutlass::layout::RowMajor, // LayoutA,
377 kAlignmentA,
378 scalar_t, // ElementB,
379 cutlass::layout::ColumnMajor, // LayoutB,
380 kAlignmentB,
381 accum_t,
382 cutlass::layout::RowMajor, // LayoutC,
383 OpClass,
384 ArchTag, // ArchTag
385 ThreadblockShape, // ThreadblockShape
386 WarpShape, // WarpShape
387 typename GemmType::InstructionShape, // InstructionShape
388 ArchTag::kMinComputeCapability >= 80 && kIsHalf
389 ? 4
390 : DefaultConfig::kStages,
391 typename GemmType::Operator // Operator
392 >::DefaultMma;
393 using MmaCore = typename DefaultMma::MmaCore;
394 using IteratorA = typename DefaultMma::IteratorA;
395 using IteratorB = typename DefaultMma::IteratorB;
396 using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma;
397 using Mma = typename cutlass::platform::conditional<
398 kSingleValueIteration,
399 typename MakeCustomMma<DefaultThreadblockMma, kMaxK>::Mma,
400 DefaultThreadblockMma>::type;
401 using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
402 typename Mma::Operator::IteratorC,
403 accum_t,
404 kWarpSize>::Iterator;
405 static_assert(
406 MmaCore::WarpCount::kM * MmaCore::WarpCount::kN *
407 MmaCore::WarpCount::kK ==
408 kNumWarpsPerBlock,
409 "");
410
411 // used for efficient load of bias tile Bij from global to shared memory
412 using BiasLoader = TileSmemLoader<
413 scalar_t,
414 cutlass::MatrixShape<kQueriesPerBlock, kKeysPerBlock>,
415 MmaCore::kThreads,
416 // input restriction: kv_len has to be a multiple of this value
417 128 / cutlass::sizeof_bits<scalar_t>::value>;
418
419 // Epilogue to store to shared-memory in a format that we can use later for
420 // the second matmul
421 using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
422 typename Mma::Operator::IteratorC,
423 typename Mma::Operator,
424 scalar_t,
425 WarpShape,
426 ThreadblockShape>;
427 using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
428 };
429
430 struct MM1 {
431 /**
432 Second matmul: perform `attn @ V` where `attn` is the (not yet
433 normalized) attention matrix, stored in shared memory
434 */
435 using GemmType = DefaultGemmType<ArchTag, scalar_t>;
436
437 using OpClass = typename GemmType::OpClass;
438 using DefaultConfig =
439 typename cutlass::gemm::device::DefaultGemmConfiguration<
440 OpClass,
441 ArchTag,
442 scalar_t,
443 scalar_t,
444 output_accum_t, // ElementC
445 accum_t // ElementAccumulator
446 >;
447 static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem
448 static constexpr int kAlignmentB =
449 kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
450 using ThreadblockShape = cutlass::gemm::
451 GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
452 using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
453 using InstructionShape = typename GemmType::InstructionShape;
454
455 using LayoutB = cutlass::layout::RowMajor;
456 using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
457 scalar_t, // ElementA,
458 cutlass::layout::RowMajor, // LayoutA,
459 kAlignmentA,
460 scalar_t, // ElementB,
461 LayoutB, // LayoutB,
462 kAlignmentB,
463 output_accum_t,
464 cutlass::layout::RowMajor, // LayoutC,
465 accum_t,
466 OpClass,
467 ArchTag,
468 ThreadblockShape,
469 WarpShape,
470 typename GemmType::InstructionShape,
471 typename DefaultConfig::EpilogueOutputOp,
472 void, // ThreadblockSwizzle - not used
473 ArchTag::kMinComputeCapability >= 80 && kIsHalf
474 ? 4
475 : DefaultConfig::kStages,
476 false, // SplitKSerial
477 typename GemmType::Operator>;
478
479 using WarpIteratorA = typename cutlass::gemm::threadblock::
480 DefaultWarpIteratorAFromSharedMemory<
481 typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape
482 typename DefaultGemm::Mma::Policy::Operator::InstructionShape,
483 typename DefaultGemm::Mma::Policy::Operator::IteratorA,
484 typename DefaultGemm::Mma::Policy>::WarpIterator;
485 using DefaultMmaFromSmem =
486 typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
487 typename DefaultGemm::Mma,
488 MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK
489 WarpIteratorA,
490 false>; // kScaleOperandA
491 using Mma = typename DefaultMmaFromSmem::Mma;
492 using IteratorB = typename Mma::IteratorB;
493 using WarpCount = typename Mma::WarpCount;
494 static_assert(
495 WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock,
496 "");
497
498 using DefaultEpilogue = typename DefaultGemm::Epilogue;
499 using OutputTileIterator =
500 typename cutlass::epilogue::threadblock::PredicatedTileIterator<
501 typename DefaultEpilogue::OutputTileIterator::ThreadMap,
502 output_t>;
503 using OutputTileIteratorAccum =
504 typename cutlass::epilogue::threadblock::PredicatedTileIterator<
505 typename DefaultEpilogue::OutputTileIterator::ThreadMap,
506 output_accum_t>;
507 };
508
509 static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
510 static constexpr int64_t kAlignmentK = MM0::kAlignmentB;
511 static constexpr int64_t kAlignmentV = 1;
512
513 // Shared storage - depends on kernel params
514 struct ScalingCoefs {
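// Per-row softmax state; roughly (see `iterative_softmax` below, which
// works in the log2 domain so it can use exp2f):
//   mi[r]            - running maximum of the scaled logits of row r
//   m_prime[r]       - row maximum carried over from the previous key block
//   s_prime[r]       - running sum of exp2(logit - mi[r])
//   out_rescale[r]   - exp2(m_prime[r] - mi[r]), used to rescale partial
//                      outputs when the row maximum increases
//   addition_storage - staging area for the deterministic cross-warp
//                      reduction of the per-row sums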
515 cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
516 cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
517 cutlass::Array<accum_t, kQueriesPerBlock> mi;
518 cutlass::Array<accum_t, kQueriesPerBlock> out_rescale;
519 cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
520 addition_storage;
521 };
522
523 struct SharedStorageEpilogueAtEnd : ScalingCoefs {
524 struct SharedStorageAfterMM0 {
525 // Everything here might be overwritten during MM0
526 union {
527 typename MM0::BiasLoader::SmemTile bias;
528 typename MM0::AccumulatorSharedStorage si;
529 };
530 typename MM1::Mma::SharedStorage mm1;
531 };
532
533 union {
534 typename MM0::Mma::SharedStorage mm0;
535 SharedStorageAfterMM0 after_mm0;
536 typename MM1::DefaultEpilogue::SharedStorage epilogue;
537 };
538
539 CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
540 epilogue_shared_storage() {
541 return epilogue;
542 }
543 };
544
545 struct SharedStorageEpilogueInLoop : ScalingCoefs {
546 struct SharedStorageAfterMM0 {
547 // Everything here might be overwritten during MM0
548 union {
549 typename MM0::BiasLoader::SmemTile bias;
550 typename MM0::AccumulatorSharedStorage si;
551 };
552 typename MM1::Mma::SharedStorage mm1;
553 typename MM1::DefaultEpilogue::SharedStorage epilogue;
554 };
555
556 union {
557 typename MM0::Mma::SharedStorage mm0;
558 SharedStorageAfterMM0 after_mm0;
559 };
560
561 CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
562 epilogue_shared_storage() {
563 return after_mm0.epilogue;
564 }
565 };
566
567 using SharedStorage = typename cutlass::platform::conditional<
568 kSingleValueIteration || kKeepOutputInRF,
569 SharedStorageEpilogueAtEnd,
570 SharedStorageEpilogueInLoop>::type;
571
572 static bool __host__ check_supported(Params const& p) {
573 CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
574 CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
575 CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);
576 if (kSupportsBias) {
577 CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
578 TORCH_CHECK(
579 p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0,
580 "attn_bias is not correctly aligned (strideB). ",
581 "attn_bias.stride( 0) = ", p.bias_strideB, ", and should be a "
582 "multiple of ", kAlignmentQ, ".");
583 TORCH_CHECK(
584 p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0,
585 "attn_bias is not correctly aligned (strideH). "
586 "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a "
587 "multiple of ", kAlignmentQ, ".");
588 TORCH_CHECK(
589 p.num_queries <= 1 || p.bias_strideM % kAlignmentQ == 0,
590 "attn_bias is not correctly aligned (strideM). "
591 "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a "
592 "multiple of ", kAlignmentQ, ".");
593 }
594 TORCH_CHECK(
595 p.q_strideM % kAlignmentQ == 0,
596 "query is not correctly aligned (strideM)");
597 TORCH_CHECK(
598 p.k_strideM % kAlignmentK == 0,
599 "key is not correctly aligned (strideM)");
600 TORCH_CHECK(
601 p.v_strideM % kAlignmentV == 0,
602 "value is not correctly aligned (strideM)");
603 TORCH_CHECK(
604 p.num_heads <= 1 || p.q_strideH % kAlignmentQ == 0,
605 "query is not correctly aligned (strideH)");
606 TORCH_CHECK(
607 p.num_heads <= 1 || p.k_strideH % kAlignmentK == 0,
608 "key is not correctly aligned (strideH)");
609 TORCH_CHECK(
610 p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0,
611 "value is not correctly aligned (strideH)");
612 TORCH_CHECK(
613 p.custom_mask_type < NumCustomMaskTypes,
614 "invalid value for `custom_mask_type`");
615 if (p.window_size > 0) {
616 TORCH_CHECK(
617 p.custom_mask_type == CausalFromTopLeft ||
618 p.custom_mask_type == CausalFromBottomRight,
619 "custom_mask_type not supported");
620 }
621 return true;
622 }
623
624 static void CUTLASS_DEVICE attention_kernel(Params& p) {
625 // In this block, we will only ever:
626 // - read query[query_start:query_end, :]
627 // - write to output[query_start:query_end, :]
628
629 extern __shared__ char smem_buffer[];
630 SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
631 auto& m_prime = shared_storage.m_prime;
632 auto& s_prime = shared_storage.s_prime;
633 auto& mi = shared_storage.mi;
634 auto& out_rescale = shared_storage.out_rescale;
635 const uint32_t query_start = blockIdx.x * kQueriesPerBlock;
636
637 static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
638 if (thread_id() < kQueriesPerBlock) {
639 s_prime[thread_id()] = accum_t(0);
640 out_rescale[thread_id()] = accum_t(1.0);
641 m_prime[thread_id()] =
642 -cutlass::platform::numeric_limits<accum_t>::infinity();
643 mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
644 }
645 typename MM1::Mma::FragmentC accum_o;
646 accum_o.clear();
647
648 auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
649 using OutputTileIterator = typename MM1::OutputTileIterator;
650 return OutputTileIterator(
651 typename OutputTileIterator::Params{(int32_t)p.o_strideM},
652 p.output_ptr,
653 typename OutputTileIterator::TensorCoord{
654 p.num_queries, p.head_dim_value},
655 thread_id(),
656 {0, col});
657 };
658
659 auto createOutputAccumIter = [&](int col) ->
660 typename MM1::OutputTileIteratorAccum {
661 using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
662 return OutputTileIteratorAccum(
663 typename OutputTileIteratorAccum::Params{
664 (int32_t)(p.head_dim_value * p.num_heads)},
665 p.output_accum_ptr,
666 typename OutputTileIteratorAccum::TensorCoord{
667 p.num_queries, p.head_dim_value},
668 thread_id(),
669 {0, col});
670 };
671
672 curandStatePhilox4_32_10_t curand_state_init;
673 if (kSupportsDropout && p.use_dropout) {
674 const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);
675 if (p.rng_engine_inputs.captured_) {
676 // See Note [Seed and Offset Device]
677 // When we are in CUDA graph capture mode, the seed and offset are stored
678 // on device. We pass in int64_t* seed and int64_t* offset to act as
679 // scratch space for storing the RNG state during the forward pass and
680 // saving it for backwards.
681 auto [seed, offset] = seeds;
682 *p.seed = seed;
683 *p.extragraph_offset = offset;
684 }
685 // Each element of the attention matrix P with shape
686 // (batch_sz, n_heads, n_queries, n_keys) is associated with a single
687 // offset in the RNG sequence. We initialize the RNG state with an offset
688 // that starts at the beginning of the (n_queries, n_keys) matrix for this
689 // block's batch_id and head_id.
690 // Initializing the RNG state is very expensive, so we do it once per
691 // kernel rather than once per iteration; each iteration takes a copy of
692 // the initialized RNG state and offsets it as needed.
693 curand_init(
694 std::get<0>(seeds),
695 0,
696 std::get<1>(seeds) + p.dropout_batch_head_rng_offset,
697 &curand_state_init);
698 }
699
700 // Iterate through keys
701 for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
702 iter_key_start += kKeysPerBlock) {
703 if (p.window_size > 0) {
704 // don't compute anything if below attention band
705 if (iter_key_start + kKeysPerBlock <
706 int32_t(query_start + p.causal_diagonal_offset) - p.window_size) {
707 continue;
708 }
709 }
710 int32_t problem_size_0_m =
711 cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries);
712 int32_t problem_size_0_n = cutlass::fast_min(
713 int32_t(kKeysPerBlock), p.num_keys - iter_key_start);
714 int32_t const& problem_size_0_k = p.head_dim;
715 int32_t const& problem_size_1_n = p.head_dim_value;
716 int32_t const& problem_size_1_k = problem_size_0_n;
717
718 auto prologueV = [&](int blockN) {
719 typename MM1::Mma::IteratorB iterator_V(
720 typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
721 const_cast<scalar_t*>(p.value_ptr + iter_key_start * p.v_strideM),
722 {problem_size_1_k, problem_size_1_n},
723 thread_id(),
724 cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
725 MM1::Mma::prologue(
726 shared_storage.after_mm0.mm1,
727 iterator_V,
728 thread_id(),
729 problem_size_1_k);
730 };
731
732 __syncthreads(); // Need to have shared memory initialized, and `m_prime`
733 // updated from end of prev iter
734 //
735 // MATMUL: Q.K_t
736 //
737 // Computes the block-matrix product of:
738 // (a) query[query_start:query_end, :]
739 // with
740 // (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
741 // and stores that into `shared_storage.si`
742 //
743
744 // Compute threadblock location
745 cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0};
746
747 cutlass::MatrixCoord tb_offset_A{
748 tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()};
749
750 cutlass::MatrixCoord tb_offset_B{
751 tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN};
752
753 // Construct iterators to A and B operands
754 typename MM0::IteratorA iterator_A(
755 typename MM0::IteratorA::Params(
756 typename MM0::MmaCore::LayoutA(p.q_strideM)),
757 const_cast<scalar_t*>(p.query_ptr),
758 {problem_size_0_m, problem_size_0_k},
759 thread_id(),
760 tb_offset_A);
761
762 typename MM0::IteratorB iterator_B(
763 typename MM0::IteratorB::Params(
764 typename MM0::MmaCore::LayoutB(p.k_strideM)),
765 const_cast<scalar_t*>(p.key_ptr + iter_key_start * p.k_strideM),
766 {problem_size_0_k, problem_size_0_n},
767 thread_id(),
768 tb_offset_B);
769
770 auto my_warp_id = warp_uniform(warp_id());
771 auto my_lane_id = lane_id();
772
773 // Construct thread-scoped matrix multiply
774 typename MM0::Mma mma(
775 shared_storage.mm0, thread_id(), my_warp_id, my_lane_id);
776
777 typename MM0::Mma::FragmentC accum;
778
779 accum.clear();
780
781 auto gemm_k_iterations =
782 (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;
783
784 // Compute threadblock-scoped matrix multiply-add
785 mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
786 __syncthreads();
787
788 if (kPreloadV) {
789 prologueV(0);
790 }
791
792 typename MM0::Mma::Operator::IteratorC::TensorCoord
793 iteratorC_tile_offset = {
794 (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) +
795 (my_warp_id % MM0::Mma::WarpCount::kM),
796 (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) +
797 (my_warp_id / MM0::Mma::WarpCount::kM)};
798
799 // multiply by scaling factor
800 if (kSupportsBias) {
801 accum =
802 cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
803 }
804
805 // apply attention bias if applicable
806 if (kSupportsBias && p.attn_bias_ptr != nullptr) {
807 // load bias tile Bij into shared memory
808 typename MM0::BiasLoader::GmemTileIterator bias_iter(
809 {cutlass::layout::RowMajor(p.bias_strideM)},
810 // attn_bias_ptr points to a matrix of size (n_queries, n_keys)
811 // for the relevant batch_id and head_id
812 const_cast<scalar_t*>(p.attn_bias_ptr + query_start * p.bias_strideM + iter_key_start),
813 {problem_size_0_m, problem_size_0_n},
814 thread_id());
815 cutlass::TensorRef<scalar_t, cutlass::layout::RowMajor> bias_tensor_ref(
816 shared_storage.after_mm0.bias.data(),
817 cutlass::layout::RowMajor(MM0::ThreadblockShape::kN));
818 typename MM0::BiasLoader::SmemTileIterator smem_tile_iter(
819 bias_tensor_ref, thread_id());
820 MM0::BiasLoader::load(bias_iter, smem_tile_iter);
821
822 // Pij += Bij, Pij is in register fragment and Bij is in shared memory
823 auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
824 my_lane_id, my_warp_id, iteratorC_tile_offset);
825 MM0::AccumLambdaIterator::iterateRows(
826 lane_offset,
827 [&](int accum_m) {},
828 [&](int accum_m, int accum_n, int idx) {
829 if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) {
830 accum[idx] += bias_tensor_ref.at({accum_m, accum_n});
831 }
832 },
833 [&](int accum_m) {});
834 }
835
836 // Mask out the tail of the block if causal.
837 // This is only needed if the upper-right corner of the current query/key
838 // block intersects the mask. That corner is at y = query_start,
839 // x = min(iter_key_start + kKeysPerBlock, num_keys). The first masked
840 // element on row y is x = y + offset = query_start + offset, so there is
841 // an intersection (and we need to mask) if
842 // min(iter_key_start + kKeysPerBlock, num_keys) >= query_start + offset.
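// Illustrative example: with kQueriesPerBlock = kKeysPerBlock = 64,
// query_start = 64, causal_diagonal_offset = 0, iter_key_start = 64 and
// num_keys = 256, we get min(64 + 64, 256) = 128 >= 64 + 0, so masking is
// needed; for local row accum_m the last visible local column is
// last_col = 64 + accum_m + 0 - 64 = accum_m, and everything to the right
// of the tile's diagonal is set to -inf.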
843 if (p.custom_mask_type &&
844 cutlass::fast_min(iter_key_start + kKeysPerBlock, p.num_keys) >=
845 (query_start + p.causal_diagonal_offset)) {
846 auto query_start = blockIdx.x * kQueriesPerBlock;
847 auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
848 my_lane_id, my_warp_id, iteratorC_tile_offset);
849 int32_t last_col;
850 MM0::AccumLambdaIterator::iterateRows(
851 lane_offset,
852 [&](int accum_m) {
853 // last absolute col is (last absolute query + offset)
854 // last local col is (last absolute query + offset -
855 // iter_key_start)
856 last_col = query_start + accum_m + p.causal_diagonal_offset -
857 iter_key_start;
858 },
859 [&](int accum_m, int accum_n, int idx) {
860 if (accum_n > last_col) {
861 accum[idx] =
862 -cutlass::platform::numeric_limits<accum_t>::infinity();
863 }
864 },
865 [&](int accum_m) {});
866 }
867
868 // Mask out the lower-left corner of the block if window_size > 0.
869 // This is only required if the current block intersects that corner.
870 // The block starts at x_lowerleft = iter_key_start and its bottom row is
871 // y = query_start + kQueriesPerBlock; the first non-masked value on that
872 // row is x_first = query_start + kQueriesPerBlock - window_size, so we
873 // mask if x_first > x_lowerleft.
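// Illustrative example: with window_size = 64, kQueriesPerBlock = 64,
// query_start = 128, causal_diagonal_offset = 0 and iter_key_start = 64,
// the condition below holds (128 + 0 + 64 - 64 = 128 >= 64); the per-row
// offset is 128 + 0 - 64 - 64 = 0, so row accum_m masks local columns
// accum_n <= accum_m, i.e. the global keys <= (128 + accum_m) - 64 that
// have fallen out of the sliding window.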
874
875 if (p.window_size > 0 &&
876 (query_start + p.causal_diagonal_offset +
877 cutlass::fast_min(
878 int32_t(kQueriesPerBlock), int32_t(p.num_queries)) -
879 p.window_size >=
880 iter_key_start)) {
881 auto query_start = blockIdx.x * kQueriesPerBlock;
882 auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
883 my_lane_id, my_warp_id, iteratorC_tile_offset);
884 int32_t first_col;
885 const int32_t offset = query_start + p.causal_diagonal_offset -
886 p.window_size - iter_key_start;
887 MM0::AccumLambdaIterator::iterateRows(
888 lane_offset,
889 [&](int accum_m) { first_col = accum_m + offset; },
890 [&](int accum_m, int accum_n, int idx) {
891 if (accum_n <= first_col) {
892 accum[idx] =
893 -cutlass::platform::numeric_limits<accum_t>::infinity();
894 }
895 },
896 [&](int accum_m) {});
897 // print_warp_accum<MM0::AccumLambdaIterator>(accum, lane_offset, 12,
898 // 12);
899 }
900
901 // Update `mi` from accum stored in registers
902 // Also does accum[i] <- exp(accum[i] - mi)
903 iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
904 accum_o,
905 accum,
906 mi,
907 m_prime,
908 s_prime,
909 out_rescale,
910 shared_storage.addition_storage,
911 my_lane_id,
912 thread_id(),
913 my_warp_id,
914 p.num_keys - iter_key_start,
915 iter_key_start == 0,
916 iteratorC_tile_offset,
917 kSupportsBias ? 1.0f : p.scale);
918
919 // Output results to shared-memory
920 int warp_idx_mn_0 = my_warp_id %
921 (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
922 auto output_tile_coords = cutlass::MatrixCoord{
923 warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
924 warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};
925
926 MM0::B2bGemm::accumToSmem(
927 shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords);
928
929 __syncthreads();
930
931 // apply dropout (if applicable) after we've written Pij to smem.
932 // dropout is applied by multiplying each element of Pij by:
933 // - 0 with probability dropout_p
934 // - 1 / (1 - dropout_p) with probability 1 - dropout_p
935 //
936 // for backward purposes we want to be able to map each element of the
937 // attention matrix to the same random uniform number as the one we used
938 // in forward, without needing to use the same iteration order or having
939 // to store the dropout matrix. It's possible to do this in registers, but
940 // it ends up being very slow: each thread owning noncontiguous strips of
941 // the Pij tile means we would have to skip around a lot in the RNG
942 // sequence and generate a single random number at a time.
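// Concretely (matching the curand_init/skipahead calls below): element
// (i, j) of this batch/head's attention matrix is associated with the RNG
// offset
//   dropout_batch_head_rng_offset + i * num_keys_absolute + j
// relative to the Philox engine's base offset, which lets forward and
// backward regenerate the exact same dropout mask independently.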
943 if (kSupportsDropout && p.use_dropout) {
944 auto si = shared_storage.after_mm0.si.accum_ref();
945 // each thread handles a contiguous sequence of elements from Sij, all
946 // coming from the same row. The reason they have to come from the same
947 // row is that sampling random numbers from a contiguous random number
948 // sequence is much more efficient than jumping around, and the
949 // linear offset of each element of S (the global matrix) maps to an
950 // offset in a random number sequence. for S, the end of a row and the
951 // beginning of the next have adjacent offsets, but for Sij, this is not
952 // necessarily the case.
953 const int num_threads = blockDim.x * blockDim.y * blockDim.z;
954 const int threads_per_row =
955 cutlass::fast_min(num_threads / problem_size_0_m, problem_size_0_n);
956 const int elts_per_thread = cutlass::round_nearest(
957 cutlass::ceil_div(problem_size_0_n, threads_per_row), 4);
958
959 const int thread_i = thread_id() / threads_per_row;
960 const int thread_start_j =
961 (thread_id() % threads_per_row) * elts_per_thread;
962
963 if (thread_i < problem_size_0_m && thread_start_j < problem_size_0_n) {
964 curandStatePhilox4_32_10_t curand_state = curand_state_init;
965 skipahead(
966 static_cast<unsigned long long>(
967 (query_start + thread_i) * p.num_keys_absolute +
968 (iter_key_start + thread_start_j)),
969 &curand_state);
970 const float dropout_scale = 1.0 / (1.0 - p.dropout_prob);
971
972 // apply dropout scaling to elements this thread is responsible for,
973 // in chunks of 4
974 for (int sij_start_col_idx = thread_start_j; sij_start_col_idx <
975 cutlass::fast_min(thread_start_j + elts_per_thread,
976 problem_size_0_n);
977 sij_start_col_idx += 4) {
978 const float4 rand_uniform_quad = curand_uniform4(&curand_state);
979
980 CUTLASS_PRAGMA_UNROLL
981 for (int quad_idx = 0; quad_idx < 4; ++quad_idx) {
982 si.at({thread_i, sij_start_col_idx + quad_idx}) *=
983 static_cast<scalar_t>(
984 dropout_scale *
985 ((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob));
986 }
987 }
988 }
989 __syncthreads(); // p.use_dropout should have same value kernel-wide
990 }
991
992 //
993 // MATMUL: Attn . V
994 // Run the matmul `attn @ V` for a block of attn and V.
995 // `attn` is read from shared memory (in `shared_storage.after_mm0.si`)
996 // `V` is read from global memory (with iterator_B)
997 //
998
999 const int64_t nBlockN = kSingleValueIteration
1000 ? 1
1001 : ceil_div(
1002 (int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN));
1003 for (int blockN = 0; blockN < nBlockN; ++blockN) {
1004 int gemm_k_iterations =
1005 (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;
1006
1007 // Compute threadblock-scoped matrix multiply-add and store it in accum
1008 // (in registers)
1009 if (!kPreloadV) {
1010 __syncthreads(); // we share shmem between mma and epilogue
1011 }
1012
1013 typename MM1::Mma::IteratorB iterator_V(
1014 typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
1015 const_cast<scalar_t*>(p.value_ptr + iter_key_start * p.v_strideM),
1016 {problem_size_1_k, problem_size_1_n},
1017 thread_id(),
1018 cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
1019 typename MM1::Mma mma_pv(
1020 // operand A: Pij_dropped in shared memory
1021 shared_storage.after_mm0.si.accum_ref(),
1022 // operand B: shared memory staging area for Vj, which is loaded
1023 // from global memory
1024 shared_storage.after_mm0.mm1.operand_B_ref(),
1025 (int)thread_id(),
1026 (int)my_warp_id,
1027 (int)my_lane_id);
1028 mma_pv.set_prologue_done(kPreloadV);
1029 if (!kKeepOutputInRF) {
1030 accum_o.clear();
1031 }
1032 mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
1033 __syncthreads();
1034
1035 if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) {
1036 prologueV(blockN + 1);
1037 }
1038
1039 if (!kKeepOutputInRF) {
1040 int first_key = 0;
1041 if (p.window_size > 0) {
1042 first_key = (cutlass::fast_max(
1043 int(query_start + p.causal_diagonal_offset) -
1044 p.window_size + 1,
1045 0) /
1046 kKeysPerBlock) *
1047 kKeysPerBlock;
1048 }
1049
1050 // int first_key_block = 0;
1051 // MM1::Mma::drain_cp_asyncs(); // TODO: figure out if this is needed for correctness
1052 DISPATCH_BOOL(
1053 iter_key_start == first_key, kIsFirst, ([&] {
1054 DISPATCH_BOOL(
1055 (iter_key_start + kKeysPerBlock) >= p.num_keys,
1056 kIsLast,
1057 ([&] {
1058 using DefaultEpilogue = typename MM1::DefaultEpilogue;
1059 using DefaultOp =
1060 typename MM1::DefaultConfig::EpilogueOutputOp;
1061 using ElementCompute = typename DefaultOp::ElementCompute;
1062 using EpilogueOutputOp = typename cutlass::epilogue::
1063 thread::MemoryEfficientAttentionNormalize<
1064 typename cutlass::platform::conditional<
1065 kIsLast,
1066 output_t,
1067 output_accum_t>::type,
1068 output_accum_t,
1069 DefaultOp::kCount,
1070 typename DefaultOp::ElementAccumulator,
1071 ElementCompute,
1072 kIsFirst,
1073 kIsLast,
1074 cutlass::Array<ElementCompute, kQueriesPerBlock>>;
1075 using Epilogue = typename cutlass::epilogue::threadblock::
1076 EpiloguePipelined<
1077 typename DefaultEpilogue::Shape,
1078 typename MM1::Mma::Operator,
1079 DefaultEpilogue::kPartitionsK,
1080 typename cutlass::platform::conditional<
1081 kIsLast,
1082 typename MM1::OutputTileIterator,
1083 typename MM1::OutputTileIteratorAccum>::type,
1084 typename DefaultEpilogue::
1085 AccumulatorFragmentIterator,
1086 typename DefaultEpilogue::WarpTileIterator,
1087 typename DefaultEpilogue::SharedLoadIterator,
1088 EpilogueOutputOp,
1089 typename DefaultEpilogue::Padding,
1090 DefaultEpilogue::kFragmentsPerIteration,
1091 true, // IterationsUnroll
1092 typename MM1::OutputTileIteratorAccum // Read
1093 // iterator
1094 >;
1095
1096 int col = blockN * MM1::Mma::Shape::kN;
1097 auto source_iter = createOutputAccumIter(col);
1098 auto dest_iter = call_conditional<
1099 kIsLast,
1100 decltype(createOutputIter),
1101 decltype(createOutputAccumIter)>::
1102 apply(createOutputIter, createOutputAccumIter, col);
1103 EpilogueOutputOp rescale(s_prime, out_rescale);
1104 Epilogue epilogue(
1105 shared_storage.epilogue_shared_storage(),
1106 thread_id(),
1107 my_warp_id,
1108 my_lane_id);
1109 epilogue(rescale, dest_iter, accum_o, source_iter);
1110 }));
1111 }));
1112 if (!kSingleValueIteration) {
1113 __syncthreads();
1114 }
1115 }
1116 }
1117 __syncthreads(); // we modify `m_prime` after
1118 }
1119
1120 if (kKeepOutputInRF) {
1121 constexpr bool kIsFirst = true;
1122 constexpr bool kIsLast = true;
1123 using DefaultEpilogue = typename MM1::DefaultEpilogue;
1124 using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
1125 using ElementCompute = typename DefaultOp::ElementCompute;
1126 using EpilogueOutputOp =
1127 typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
1128 output_t, // output
1129 output_accum_t, // source
1130 DefaultOp::kCount,
1131 typename DefaultOp::ElementAccumulator, // accum
1132 output_accum_t, // compute
1133 kIsFirst,
1134 kIsLast,
1135 cutlass::Array<ElementCompute, kQueriesPerBlock>>;
1136 using Epilogue =
1137 typename cutlass::epilogue::threadblock::EpiloguePipelined<
1138 typename DefaultEpilogue::Shape,
1139 typename MM1::Mma::Operator,
1140 DefaultEpilogue::kPartitionsK,
1141 typename MM1::OutputTileIterator, // destination
1142 typename DefaultEpilogue::AccumulatorFragmentIterator,
1143 typename DefaultEpilogue::WarpTileIterator,
1144 typename DefaultEpilogue::SharedLoadIterator,
1145 EpilogueOutputOp,
1146 typename DefaultEpilogue::Padding,
1147 DefaultEpilogue::kFragmentsPerIteration,
1148 true, // IterationsUnroll
1149 typename MM1::OutputTileIteratorAccum // source tile
1150 >;
1151 auto dest_iter = createOutputIter(0);
1152 EpilogueOutputOp rescale(s_prime, out_rescale);
1153 Epilogue epilogue(
1154 shared_storage.epilogue_shared_storage(),
1155 thread_id(),
1156 warp_id(),
1157 lane_id());
1158 epilogue(rescale, dest_iter, accum_o);
1159 }
1160
1161 // 7. Calculate logsumexp
1162 // To make the backward easier, we pad logsumexp with `inf`.
1163 // This avoids a few bound checks, and is not more expensive during fwd.
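// For each valid row q this stores
//   lse[q] = mi[q] / log2(e) + log(s_prime[q])
// (mi was accumulated on logits pre-scaled by log2(e)), i.e. the usual
// logsumexp of the scaled attention scores in natural-log units.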
1164 static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
1165 if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
1166 auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
1167 constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
1168 if (thread_id() < p.num_queries) {
1169 // For fully masked-out rows, the sumexp is 0 (and the row max is -inf).
1170 // We update the sumexp to 1 prior to calling log so that log(1) = 0
1171 s_prime[thread_id()] = (s_prime[thread_id()] == 0) ? 1: s_prime[thread_id()];
1172 mi[thread_id()] = (mi[thread_id()] == -cutlass::platform::numeric_limits<accum_t>::infinity()) ? 0: mi[thread_id()];
1173 p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) +
1174 cutlass::fast_log(accum_t(s_prime[thread_id()]));
1175 } else if (thread_id() < lse_dim) {
1176 p.logsumexp_ptr[thread_id()] =
1177 cutlass::platform::numeric_limits<accum_t>::infinity();
1178 }
1179 }
1180 }
1181
1182 template <typename WarpIteratorC>
1183 CUTLASS_DEVICE static void iterative_softmax(
1184 typename WarpIteratorC::Fragment& frag_o, // output so far
1185 typename WarpIteratorC::Fragment& frag,
1186 cutlass::Array<accum_t, kQueriesPerBlock>& mi,
1187 cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
1188 cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
1189 cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale,
1190 cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
1191 addition_storage,
1192 int8_t lane_id,
1193 int8_t thread_id,
1194 int8_t warp_id,
1195 int max_col,
1196 bool is_first,
1197 typename WarpIteratorC::TensorCoord const& tile_offset,
1198 float scaling) {
1199 /* Iterates on the accumulator and corresponding position on result matrix
1200
1201 (1) Update `mi[r]` to the max value of the row `r`
1202 (2) In a second iteration do the following:
1203 (a) accum <- exp(accum - mi)
1204 (b) m_prime <- exp(m_prime - mi)
1205 (c) s_prime <- s_prime * m_prime + sum(accum)
1206
1207 All of this is done on registers, before we store all of this
1208 on shared memory for the next matmul with Value.
1209 */
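// A scalar (single row r, single thread) sketch of the update performed
// below, for reference only - the real code spreads rows and columns over
// warp-level accumulator fragments:
//
//   // `frag` has already been multiplied by `scaling * log2(e)`
//   float m_new = m_prime[r];
//   for (float z : row_block) m_new = max(m_new, z);   // -> mi[r]
//   float rescale = exp2f(m_prime[r] - m_new);         // -> out_rescale[r]
//   s_prime[r] *= rescale;          // rescale the running sum
//   // (when kKeepOutputInRF, frag_o is rescaled by the same factor)
//   float row_sum = 0.f;
//   for (float& z : row_block) { z = exp2f(z - m_new); row_sum += z; }
//   s_prime[r] += row_sum;
//   m_prime[r] = m_new;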
1210 using Fragment = typename WarpIteratorC::Fragment;
1211 using LambdaIterator = typename DefaultMmaAccumLambdaIterator<
1212 WarpIteratorC,
1213 accum_t,
1214 kWarpSize>::Iterator;
1215 // Convert to `accum_t` (rather than double)
1216 constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
1217
1218 static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
1219 static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;
1220
1221 frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
1222
1223 auto lane_offset =
1224 LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
1225
1226 // First update `mi` to the max per-row
1227 {
1228 accum_t max;
1229 LambdaIterator::iterateRows(
1230 lane_offset,
1231 [&](int accum_m) {
1232 max = -cutlass::platform::numeric_limits<accum_t>::infinity();
1233 },
1234 [&](int accum_m, int accum_n, int idx) {
1235 if (accum_n < max_col) {
1236 max = cutlass::fast_max(max, frag[idx]);
1237 }
1238 },
1239 [&](int accum_m) {
1240 // Having 4x atomicMax seems faster than reduce within warp
1241 // first...
1242 atomicMaxFloat(&mi[accum_m], max);
1243 });
1244 }
1245
1246 // Make sure we all share the update values for `mi`
1247 __syncthreads();
1248
1249 // Doing this `exp` is quite expensive. Let's
1250 // split it across the warps
1251 bool restore_mi_to_minus_inf = false;
1252 if (lane_id < kLinesPerWarp) {
1253 int id = warp_id * kLinesPerWarp + lane_id;
1254 auto m_prime_id = m_prime[id];
1255 auto mi_id = mi[id];
1256 bool changed = m_prime_id < mi_id; // `false` if both are -inf
1257 if (changed) {
1258 auto m_prime_exp = exp2f(m_prime_id - mi_id);
1259 out_rescale[id] = m_prime_exp;
1260 s_prime[id] *= m_prime_exp;
1261 } else {
1262 // Only when bias is enabled, it's possible that all the first values
1263 // of attention are masked to `-inf`. In that case we want to avoid
1264 // `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0
1265 if (kSupportsBias &&
1266 mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
1267 restore_mi_to_minus_inf = true;
1268 mi[id] = 0.0f;
1269 }
1270 out_rescale[id] = 1.0f;
1271 }
1272 }
1273 __syncthreads(); // Update output fragments
1274 if (kKeepOutputInRF && !is_first) {
1275 accum_t line_rescale;
1276 LambdaIterator::iterateRows(
1277 lane_offset,
1278 [&](int accum_m) { line_rescale = out_rescale[accum_m]; },
1279 [&](int accum_m, int accum_n, int idx) {
1280 frag_o[idx] = frag_o[idx] * line_rescale;
1281 },
1282 [&](int accum_m) {});
1283 }
1284 // Update accum_m, accum_n, ...
1285 {
1286 accum_t mi_row, total_row;
1287 LambdaIterator::iterateRows(
1288 lane_offset,
1289 [&](int accum_m) { mi_row = mi[accum_m]; },
1290 [&](int accum_m, int accum_n, int idx) {
1291 frag[idx] =
1292 (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
1293 },
1294 [&](int accum_m) {});
1295 LambdaIterator::iterateRows(
1296 lane_offset,
1297 [&](int accum_m) { total_row = 0.0; },
1298 [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
1299 [&](int accum_m) {
1300 if (LambdaIterator::reduceSameRow(
1301 lane_id, total_row, [](accum_t a, accum_t b) {
1302 return a + b;
1303 })) {
1304 // NOTE: we could atomically add `total_row` to `s_prime`, but
1305 // it's faster (and deterministic) to avoid atomics here
1306 addition_storage
1307 [accum_m + kQueriesPerBlock * tile_offset.column()] =
1308 total_row;
1309 }
1310 });
1311 }
1312 __syncthreads();
1313 if (lane_id < kLinesPerWarp) {
1314 int id = warp_id * kLinesPerWarp + lane_id;
1315 accum_t total_row = s_prime[id];
1316 if (restore_mi_to_minus_inf) {
1317 // Restore `mi`, see above when we set `restore_mi_to_minus_inf=true`
1318 mi[id] = -cutlass::platform::numeric_limits<accum_t>::infinity();
1319 } else {
1320 m_prime[id] = mi[id];
1321 }
1322 CUTLASS_PRAGMA_UNROLL
1323 for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
1324 total_row += addition_storage[id + kQueriesPerBlock * i];
1325 }
1326 s_prime[id] = total_row;
1327 }
1328 }
1329
1330 static CUTLASS_DEVICE int8_t lane_id() {
1331 return threadIdx.x;
1332 }
1333 static CUTLASS_DEVICE int8_t warp_id() {
1334 return threadIdx.y;
1335 }
1336 static CUTLASS_DEVICE int16_t thread_id() {
1337 return threadIdx.x + threadIdx.y * blockDim.x;
1338 }
1339 };
1340
1341 template <typename AK>
1342 __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
1343 attention_kernel_batched_impl(typename AK::Params p) {
1344 if (!p.advance_to_block()) {
1345 return;
1346 }
1347 AK::attention_kernel(p);
1348 }
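// A minimal host-side launch sketch (illustrative only; the template
// arguments and the cudaFuncSetAttribute call below are assumptions for the
// example, not something this header mandates):
//
//   using AK = AttentionKernel<cutlass::half_t, cutlass::arch::Sm80,
//                              /*isAligned_=*/true,
//                              /*kQueriesPerBlock_=*/64,
//                              /*kKeysPerBlock_=*/64,
//                              /*kMaxK_=*/64>;
//   typename AK::Params p; // fill in pointers, strides and sizes first
//   TORCH_CHECK(AK::check_supported(p));
//   auto kernel = attention_kernel_batched_impl<AK>;
//   int smem_bytes = int(sizeof(typename AK::SharedStorage));
//   if (smem_bytes > 48 * 1024) {
//     cudaFuncSetAttribute(
//         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
//   }
//   kernel<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);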
1349
1350 template <typename AK>
1351 __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
1352 attention_kernel_batched(typename AK::Params params);
1353
1354 } // namespace PyTorchMemEffAttention
1355