1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 #pragma once
9 
10 #include <ATen/cuda/PhiloxUtils.cuh>
11 #include <c10/util/Exception.h>
12 
13 #include <curand_kernel.h>
14 #include <cmath>
15 #include <vector>
16 
17 #include <cutlass/bfloat16.h>
18 #include <cutlass/fast_math.h>
19 #include <cutlass/gemm/gemm.h>
20 #include <cutlass/layout/matrix.h>
21 #include <cutlass/layout/vector.h>
22 #include <cutlass/matrix.h>
23 #include <cutlass/numeric_types.h>
24 #include <cutlass/tensor_ref.h>
25 
26 #include <cutlass/epilogue/threadblock/default_epilogue_simt.h>
27 #include <cutlass/epilogue/threadblock/default_epilogue_tensor_op.h>
28 #include <cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h>
29 
30 #include <cutlass/gemm/device/default_gemm_configuration.h>
31 #include <cutlass/gemm/kernel/default_gemm.h>
32 #include <cutlass/gemm/threadblock/default_mma.h>
33 #include <cutlass/gemm/threadblock/default_mma_core_simt.h>
34 #include <cutlass/gemm/threadblock/default_mma_core_sm70.h>
35 #include <cutlass/gemm/threadblock/default_mma_core_sm75.h>
36 #include <cutlass/gemm/threadblock/default_mma_core_sm80.h>
37 #include <cutlass/gemm/threadblock/threadblock_swizzle.h>
38 #include <cutlass/matrix_shape.h>
39 #include <cutlass/platform/platform.h>
40 #include <cutlass/transform/threadblock/predicated_tile_iterator.h>
41 
42 #include <ATen/native/transformers/cuda/mem_eff_attention/debug_utils.h>
43 #include <ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h>
44 #include <ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h>
45 
46 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma.h>
47 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm/find_default_mma.h>
48 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h>
49 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h>
50 #include <ATen/native/transformers/cuda/mem_eff_attention/transform/tile_smem_loader.h>
51 
52 #include <cinttypes>
53 
54 using namespace gemm_kernel_utils;
55 
56 namespace PyTorchMemEffAttention {
57 namespace {
58 template <typename scalar_t, typename Arch>
59 constexpr int getWarpsPerSmFw() {
60   return (
61       Arch::kMinComputeCapability >= 80 &&
62               !cutlass::platform::is_same<scalar_t, float>::value
63           ? 16
64           : 12);
65 }
66 static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
67   // source: https://stackoverflow.com/a/51549250
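  // Explanatory note: for non-negative IEEE-754 floats, the bit pattern
  // reinterpreted as a signed int increases monotonically with the float
  // value, so atomicMax on the int view yields the float max. For negative
  // floats the bit pattern reinterpreted as an unsigned int decreases as the
  // float increases, so atomicMin on the unsigned view is used instead.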
68   return !signbit(value)
69              ? __int_as_float(atomicMax((int *)addr, __float_as_int(value)))
70              : __uint_as_float(
71                    atomicMin((unsigned int *)addr, __float_as_uint(value)));
72 }
73 } // namespace
74 
75 template <
76     // The datatype of Q/K/V
77     typename scalar_t_,
78     // Architecture we are targeting (eg `cutlass::arch::Sm80`)
79     typename ArchTag,
80     // If Q/K/V are correctly aligned in memory and we can run a fast kernel
81     bool isAligned_,
82     int kQueriesPerBlock_,
83     int kKeysPerBlock_,
84     // Upper bound on `max(value.shape[-1], query.shape[-1])`
85     int kMaxK_ = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
86     // This is noticeably slower on V100 for some reason
87     // Set to false if you know at compile-time you will never need dropout
88     bool kSupportsDropout_ = true,
89     bool kSupportsBias_ = true>
90 struct AttentionKernel {
91   enum CustomMaskType {
92     NoCustomMask = 0,
93     CausalFromTopLeft = 1,
94     CausalFromBottomRight = 2,
95     NumCustomMaskTypes,
96   };
97 
98   using scalar_t = scalar_t_;
99   using accum_t = float;
100   using lse_scalar_t = float;
101   using output_t = scalar_t;
102   // Accumulator between 2 iterations
103   // Using `accum_t` improves perf on f16 at the cost of
104   // numerical errors
105   using output_accum_t = accum_t;
106   static constexpr bool kSupportsDropout = kSupportsDropout_;
107   static constexpr bool kSupportsBias = kSupportsBias_;
108   static constexpr int kKeysPerBlock = kKeysPerBlock_;
109   static constexpr int kQueriesPerBlock = kQueriesPerBlock_;
110   static constexpr int kMaxK = kMaxK_;
111   static constexpr bool kIsAligned = isAligned_;
112   static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock;
113   static constexpr int32_t kAlignLSE = 32; // block size of backward
114   static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value == 16;
115   static constexpr bool kPreloadV =
116       ArchTag::kMinComputeCapability >= 80 && kIsHalf;
117   static constexpr bool kKeepOutputInRF = kSingleValueIteration;
118   static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
119       !cutlass::platform::is_same<output_accum_t, output_t>::value;
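  // Note: when the value head dimension (bounded by kMaxK) fits in a single
  // kKeysPerBlock-wide MM1 tile (kSingleValueIteration), the output block is
  // kept in registers (kKeepOutputInRF) and rescaled in place, so no separate
  // accumulator buffer in global memory is needed.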
120 
121   static_assert(kQueriesPerBlock % 32 == 0, "");
122   static_assert(kKeysPerBlock % 32 == 0, "");
123   static constexpr int kNumWarpsPerBlock =
124       kQueriesPerBlock * kKeysPerBlock / (32 * 32);
125   static constexpr int kWarpSize = 32;
126 
127   // Launch bounds
128   static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
129   static constexpr int kMinBlocksPerSm =
130       getWarpsPerSmFw<scalar_t, ArchTag>() / kNumWarpsPerBlock;
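  // Worked example (illustrative): with kQueriesPerBlock = 64 and
  // kKeysPerBlock = 64, kNumWarpsPerBlock = 64 * 64 / (32 * 32) = 4 and
  // kNumThreads = 128. On Sm80 with f16 inputs, getWarpsPerSmFw() returns 16,
  // so kMinBlocksPerSm = 16 / 4 = 4.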
131 
132   struct Params {
133     // Input tensors
134     const scalar_t* query_ptr = nullptr; // [num_queries, num_heads, head_dim]
135     const scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim]
136     const scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value]
137     const scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys]
138     const int32_t* seqstart_q_ptr = nullptr;
139     const int32_t* seqstart_k_ptr = nullptr;
140 
141     const int32_t* seqlen_k_ptr = nullptr;
142     uint32_t causal_diagonal_offset = 0;
143 
144     // Output tensors
145     output_t* output_ptr = nullptr; // [num_queries, num_heads, head_dim_value]
146     // [num_queries, num_heads, head_dim_value]
147     output_accum_t* output_accum_ptr = nullptr;
148     // [num_heads, num_queries] - can be null
149     lse_scalar_t* logsumexp_ptr = nullptr;
150 
151     // Sliding window. Ignored if == 0.
152     int32_t window_size = 0;
153 
154     // Scale
155     accum_t scale = 0.0;
156 
157     // Dimensions/strides
158     int32_t head_dim = 0;
159     int32_t head_dim_value = 0;
160     int32_t num_queries = 0;
161     int32_t num_keys = 0;
162     int32_t num_keys_absolute = 0;
163 
164     uint8_t custom_mask_type = NoCustomMask;
165 
166     int32_t q_strideM = 0;
167     int32_t k_strideM = 0;
168     int32_t v_strideM = 0;
169     int32_t bias_strideM = 0;
170 
171     int32_t o_strideM = 0;
172 
173     // Everything below is only used in `advance_to_block`
174     // and shouldn't use registers
175     int32_t q_strideH = 0;
176     int32_t k_strideH = 0;
177     int32_t v_strideH = 0;
178     int64_t bias_strideH = 0;
179 
180     int64_t q_strideB = 0;
181     int64_t k_strideB = 0;
182     int64_t v_strideB = 0;
183     int64_t bias_strideB = 0;
184 
185     int32_t num_batches = 0;
186     int32_t num_heads = 0;
187 
188     // dropout
189     bool use_dropout = false;
190     unsigned long long dropout_batch_head_rng_offset = 0;
191     float dropout_prob = 0.0f;
192     at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0);
193     int64_t* extragraph_offset = nullptr;
194     int64_t* seed = nullptr;
195 
196     // Moves pointers to what we should process
197     // Returns "false" if there is no work to do
198     CUTLASS_DEVICE bool advance_to_block() {
199       auto batch_id = blockIdx.z;
200       auto head_id = blockIdx.y;
201       auto query_start = blockIdx.x * kQueriesPerBlock;
202 
203       auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;
204 
205       if (kSupportsDropout) {
206         dropout_batch_head_rng_offset =
207             batch_id * num_heads * num_queries * num_keys +
208             head_id * num_queries * num_keys;
209       }
210 
211       int64_t q_start = 0, k_start = 0;
212       // Advance to current batch - in case of different sequence lengths
213       if (seqstart_q_ptr != nullptr) {
214         assert(seqstart_k_ptr != nullptr);
215         seqstart_q_ptr += batch_id;
216 
217         q_start = seqstart_q_ptr[0];
218         int64_t q_next_start = seqstart_q_ptr[1];
219         int64_t k_end;
220         seqstart_k_ptr += batch_id;
221 
222         if (seqlen_k_ptr) {
223           k_start = seqstart_k_ptr[0];
224           k_end = k_start + seqlen_k_ptr[batch_id];
225         } else {
226           k_start = seqstart_k_ptr[0];
227           k_end = seqstart_k_ptr[1];
228         }
229 
230         num_queries = q_next_start - q_start;
231         num_keys = k_end - k_start;
232 
233         if (query_start >= num_queries) {
234           return false;
235         }
236       } else {
237         query_ptr += batch_id * q_strideB;
238         key_ptr += batch_id * k_strideB;
239         value_ptr += batch_id * v_strideB;
240         output_ptr += int64_t(batch_id * num_queries) * o_strideM;
241         if (output_accum_ptr != nullptr) {
242           output_accum_ptr +=
243               int64_t(batch_id * num_queries) * (head_dim_value * num_heads);
244         }
245         q_start = 0;
246         k_start = 0;
247       }
248 
249       // Advance to the current batch / head / query_start
250       query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH;
251       key_ptr += k_start * k_strideM + head_id * k_strideH;
252 
253       value_ptr += k_start * v_strideM + head_id * v_strideH;
254       output_ptr +=
255           int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value;
256 
257       if (kSupportsBias && attn_bias_ptr != nullptr) {
258         attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH);
259       }
260       if (output_accum_ptr != nullptr) {
261         output_accum_ptr +=
262             int64_t(q_start + query_start) * (head_dim_value * num_heads) +
263             head_id * head_dim_value;
264       } else {
265         // Accumulate directly in the destination buffer (eg for f32)
266         output_accum_ptr = (accum_t*)output_ptr;
267       }
268 
269       if (logsumexp_ptr != nullptr) {
270         // lse[batch_id, head_id, query_start]
271         logsumexp_ptr +=
272             batch_id * lse_dim * num_heads + head_id * lse_dim + query_start;
273       }
274 
275       // Custom masking
276       if (custom_mask_type == CausalFromBottomRight) {
277         causal_diagonal_offset = num_keys - num_queries;
278       }
279       // We use num_keys_absolute to index into the rng_state
280       // We need this index to match between forward and backwards
281       num_keys_absolute = num_keys;
282       if (custom_mask_type == CausalFromTopLeft ||
283           custom_mask_type == CausalFromBottomRight) {
284         // The bottom row of the current block is query_start +
285         // kQueriesPerBlock, so the last active key is query_start +
286         // causal_diagonal_offset + kQueriesPerBlock. We clamp num_keys to
287         // the min of this and the actual num_keys to avoid extra work.
288         num_keys = cutlass::fast_min(
289             int32_t(query_start + causal_diagonal_offset + kQueriesPerBlock),
290             num_keys);
291       }
292 
293       num_queries -= query_start;
294       num_batches = 0; // no longer used after
295 
296       // If num_queries == 1 and there is only one key head, we're wasting
297       // 15/16th of the tensor core compute. In that case:
298       //  - we only launch kernels for head_id % kQueriesPerBlock == 0
299       //  - we iterate over heads instead of queries (strideM = strideH)
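      //    (this matches the multi-query / single-token decoding case, where
      //     all heads share the same K/V, i.e. k_strideH == 0 and v_strideH == 0)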
300       if (num_queries == 1 && k_strideH == 0 && v_strideH == 0 &&
301           logsumexp_ptr == nullptr && window_size == 0) {
302         if (head_id % kQueriesPerBlock != 0) {
303           return false;
304         }
305         q_strideM = q_strideH;
306         bias_strideM = bias_strideH;
307         num_queries = num_heads;
308         num_heads = 1; // unused but here for intent
309         // remove causal since n_query = 1
310         // otherwise, the offset would change with the head!
311         custom_mask_type = NoCustomMask;
312         o_strideM = head_dim_value;
313       }
314 
315       // Make sure the compiler knows these variables are the same on all
316       // the threads of the warp.
317       // Only worth doing if they could have been modified above.
318       query_ptr = warp_uniform(query_ptr);
319       key_ptr = warp_uniform(key_ptr);
320       value_ptr = warp_uniform(value_ptr);
321       if (kSupportsBias) {
322         attn_bias_ptr = warp_uniform(attn_bias_ptr);
323       }
324       output_ptr = warp_uniform(output_ptr);
325       output_accum_ptr = warp_uniform(output_accum_ptr);
326       logsumexp_ptr = warp_uniform(logsumexp_ptr);
327       num_queries = warp_uniform(num_queries);
328       num_keys = warp_uniform(num_keys);
329       num_heads = warp_uniform(num_heads);
330       o_strideM = warp_uniform(o_strideM);
331       custom_mask_type = warp_uniform(custom_mask_type);
332       return true;
333     }
334 
335     __host__ dim3 getBlocksGrid() const {
336       return dim3(
337           ceil_div(num_queries, (int32_t)kQueriesPerBlock),
338           num_heads,
339           num_batches);
340     }
341 
342     __host__ dim3 getThreadsGrid() const {
343       return dim3(kWarpSize, kNumWarpsPerBlock, 1);
344     }
345   };
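  // Grid mapping assumed by advance_to_block() and getBlocksGrid():
  // blockIdx.x indexes the query tile (query_start = blockIdx.x * kQueriesPerBlock),
  // blockIdx.y indexes the head, and blockIdx.z indexes the batch.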
346 
347   struct MM0 {
348     /*
349       In this first matmul, we compute a block of `Q @ K.T`.
350       While the calculation result is still hot in registers, we update
351       `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value
352       into a shared-memory ("AccumulatorSharedStorage") that is used later as
353       operand A for the second matmul (see MM1)
354     */
355     using GemmType = DefaultGemmType<ArchTag, scalar_t>;
356 
357     using OpClass = typename GemmType::OpClass;
358     using DefaultConfig =
359         typename cutlass::gemm::device::DefaultGemmConfiguration<
360             OpClass,
361             ArchTag,
362             scalar_t,
363             scalar_t,
364             scalar_t, // ElementC
365             accum_t // ElementAccumulator
366             >;
367     static constexpr int kAlignmentA =
368         kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment;
369     static constexpr int kAlignmentB =
370         kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
371     using ThreadblockShape = cutlass::gemm::
372         GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
373     using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
374     using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma<
375         scalar_t, // ElementA,
376         cutlass::layout::RowMajor, // LayoutA,
377         kAlignmentA,
378         scalar_t, // ElementB,
379         cutlass::layout::ColumnMajor, // LayoutB,
380         kAlignmentB,
381         accum_t,
382         cutlass::layout::RowMajor, // LayoutC,
383         OpClass,
384         ArchTag, // ArchTag
385         ThreadblockShape, // ThreadblockShape
386         WarpShape, // WarpShape
387         typename GemmType::InstructionShape, // InstructionShape
388         ArchTag::kMinComputeCapability >= 80 && kIsHalf
389             ? 4
390             : DefaultConfig::kStages,
391         typename GemmType::Operator // Operator
392         >::DefaultMma;
393     using MmaCore = typename DefaultMma::MmaCore;
394     using IteratorA = typename DefaultMma::IteratorA;
395     using IteratorB = typename DefaultMma::IteratorB;
396     using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma;
397     using Mma = typename cutlass::platform::conditional<
398         kSingleValueIteration,
399         typename MakeCustomMma<DefaultThreadblockMma, kMaxK>::Mma,
400         DefaultThreadblockMma>::type;
401     using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
402         typename Mma::Operator::IteratorC,
403         accum_t,
404         kWarpSize>::Iterator;
405     static_assert(
406         MmaCore::WarpCount::kM * MmaCore::WarpCount::kN *
407                 MmaCore::WarpCount::kK ==
408             kNumWarpsPerBlock,
409         "");
410 
411     // used for efficient load of bias tile Bij from global to shared memory
412     using BiasLoader = TileSmemLoader<
413         scalar_t,
414         cutlass::MatrixShape<kQueriesPerBlock, kKeysPerBlock>,
415         MmaCore::kThreads,
416         // input restriction: kv_len has to be a multiple of this value
417         128 / cutlass::sizeof_bits<scalar_t>::value>;
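    // For reference, the last template parameter evaluates to 8 elements per
    // 128-bit load for 16-bit types (128 / 16) and 4 for f32 (128 / 32).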
418 
419     // Epilogue to store to shared-memory in a format that we can use later for
420     // the second matmul
421     using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
422         typename Mma::Operator::IteratorC,
423         typename Mma::Operator,
424         scalar_t,
425         WarpShape,
426         ThreadblockShape>;
427     using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
428   };
429 
430   struct MM1 {
431     /**
432       Second matmul: perform `attn @ V`, where `attn` is the (not yet
433       normalized) attention stored in shared memory
434     */
435     using GemmType = DefaultGemmType<ArchTag, scalar_t>;
436 
437     using OpClass = typename GemmType::OpClass;
438     using DefaultConfig =
439         typename cutlass::gemm::device::DefaultGemmConfiguration<
440             OpClass,
441             ArchTag,
442             scalar_t,
443             scalar_t,
444             output_accum_t, // ElementC
445             accum_t // ElementAccumulator
446             >;
447     static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem
448     static constexpr int kAlignmentB =
449         kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
450     using ThreadblockShape = cutlass::gemm::
451         GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
452     using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
453     using InstructionShape = typename GemmType::InstructionShape;
454 
455     using LayoutB = cutlass::layout::RowMajor;
456     using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
457         scalar_t, // ElementA,
458         cutlass::layout::RowMajor, // LayoutA,
459         kAlignmentA,
460         scalar_t, // ElementB,
461         LayoutB, // LayoutB,
462         kAlignmentB,
463         output_accum_t,
464         cutlass::layout::RowMajor, // LayoutC,
465         accum_t,
466         OpClass,
467         ArchTag,
468         ThreadblockShape,
469         WarpShape,
470         typename GemmType::InstructionShape,
471         typename DefaultConfig::EpilogueOutputOp,
472         void, // ThreadblockSwizzle - not used
473         ArchTag::kMinComputeCapability >= 80 && kIsHalf
474             ? 4
475             : DefaultConfig::kStages,
476         false, // SplitKSerial
477         typename GemmType::Operator>;
478 
479     using WarpIteratorA = typename cutlass::gemm::threadblock::
480         DefaultWarpIteratorAFromSharedMemory<
481             typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape
482             typename DefaultGemm::Mma::Policy::Operator::InstructionShape,
483             typename DefaultGemm::Mma::Policy::Operator::IteratorA,
484             typename DefaultGemm::Mma::Policy>::WarpIterator;
485     using DefaultMmaFromSmem =
486         typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
487             typename DefaultGemm::Mma,
488             MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK
489             WarpIteratorA,
490             false>; // kScaleOperandA
491     using Mma = typename DefaultMmaFromSmem::Mma;
492     using IteratorB = typename Mma::IteratorB;
493     using WarpCount = typename Mma::WarpCount;
494     static_assert(
495         WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock,
496         "");
497 
498     using DefaultEpilogue = typename DefaultGemm::Epilogue;
499     using OutputTileIterator =
500         typename cutlass::epilogue::threadblock::PredicatedTileIterator<
501             typename DefaultEpilogue::OutputTileIterator::ThreadMap,
502             output_t>;
503     using OutputTileIteratorAccum =
504         typename cutlass::epilogue::threadblock::PredicatedTileIterator<
505             typename DefaultEpilogue::OutputTileIterator::ThreadMap,
506             output_accum_t>;
507   };
508 
509   static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
510   static constexpr int64_t kAlignmentK = MM0::kAlignmentB;
511   static constexpr int64_t kAlignmentV = 1;
512 
513   // Shared storage - depends on kernel params
514   struct ScalingCoefs {
515     cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
516     cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
517     cutlass::Array<accum_t, kQueriesPerBlock> mi;
518     cutlass::Array<accum_t, kQueriesPerBlock> out_rescale;
519     cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
520         addition_storage;
521   };
522 
523   struct SharedStorageEpilogueAtEnd : ScalingCoefs {
524     struct SharedStorageAfterMM0 {
525       // Everything here might be overwritten during MM0
526       union {
527         typename MM0::BiasLoader::SmemTile bias;
528         typename MM0::AccumulatorSharedStorage si;
529       };
530       typename MM1::Mma::SharedStorage mm1;
531     };
532 
533     union {
534       typename MM0::Mma::SharedStorage mm0;
535       SharedStorageAfterMM0 after_mm0;
536       typename MM1::DefaultEpilogue::SharedStorage epilogue;
537     };
538 
539     CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
540     epilogue_shared_storage() {
541       return epilogue;
542     }
543   };
544 
545   struct SharedStorageEpilogueInLoop : ScalingCoefs {
546     struct SharedStorageAfterMM0 {
547       // Everything here might be overwritten during MM0
548       union {
549         typename MM0::BiasLoader::SmemTile bias;
550         typename MM0::AccumulatorSharedStorage si;
551       };
552       typename MM1::Mma::SharedStorage mm1;
553       typename MM1::DefaultEpilogue::SharedStorage epilogue;
554     };
555 
556     union {
557       typename MM0::Mma::SharedStorage mm0;
558       SharedStorageAfterMM0 after_mm0;
559     };
560 
561     CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
562     epilogue_shared_storage() {
563       return after_mm0.epilogue;
564     }
565   };
566 
567   using SharedStorage = typename cutlass::platform::conditional<
568       kSingleValueIteration || kKeepOutputInRF,
569       SharedStorageEpilogueAtEnd,
570       SharedStorageEpilogueInLoop>::type;
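  // When the epilogue runs only once at the very end (kKeepOutputInRF), its
  // shared storage can alias the MM0/MM1 storage (SharedStorageEpilogueAtEnd);
  // when it runs inside the key loop, it has to coexist with `mm1`, hence the
  // larger SharedStorageEpilogueInLoop layout.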
571 
572   static bool __host__ check_supported(Params const& p) {
573     CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
574     CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
575     CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);
576     if (kSupportsBias) {
577       CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
578       TORCH_CHECK(
579           p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0,
580           "attn_bias is not correctly aligned (strideB). ",
581           "attn_bias.stride( 0) = ", p.bias_strideB, ", and should be a "
582           "multiple of ", kAlignmentQ, ".");
583       TORCH_CHECK(
584           p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0,
585           "attn_bias is not correctly aligned (strideH). "
586           "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a "
587           "multiple of ", kAlignmentQ, ".");
588       TORCH_CHECK(
589           p.num_queries <= 1 || p.bias_strideM % kAlignmentQ == 0,
590           "attn_bias is not correctly aligned (strideM). "
591           "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a "
592           "multiple of ", kAlignmentQ, ".");
593     }
594     TORCH_CHECK(
595         p.q_strideM % kAlignmentQ == 0,
596         "query is not correctly aligned (strideM)");
597     TORCH_CHECK(
598         p.k_strideM % kAlignmentK == 0,
599         "key is not correctly aligned (strideM)");
600     TORCH_CHECK(
601         p.v_strideM % kAlignmentV == 0,
602         "value is not correctly aligned (strideM)");
603     TORCH_CHECK(
604         p.num_heads <= 1 || p.q_strideH % kAlignmentQ == 0,
605         "query is not correctly aligned (strideH)");
606     TORCH_CHECK(
607         p.num_heads <= 1 || p.k_strideH % kAlignmentK == 0,
608         "key is not correctly aligned (strideH)");
609     TORCH_CHECK(
610         p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0,
611         "value is not correctly aligned (strideH)");
612     TORCH_CHECK(
613         p.custom_mask_type < NumCustomMaskTypes,
614         "invalid value for `custom_mask_type`");
615     if (p.window_size > 0) {
616       TORCH_CHECK(
617           p.custom_mask_type == CausalFromTopLeft ||
618               p.custom_mask_type == CausalFromBottomRight,
619           "custom_mask_type not supported");
620     }
621     return true;
622   }
623 
624   static void CUTLASS_DEVICE attention_kernel(Params& p) {
625     // In this block, we will only ever:
626     // - read query[query_start:query_end, :]
627     // - write to output[query_start:query_end, :]
628 
629     extern __shared__ char smem_buffer[];
630     SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
631     auto& m_prime = shared_storage.m_prime;
632     auto& s_prime = shared_storage.s_prime;
633     auto& mi = shared_storage.mi;
634     auto& out_rescale = shared_storage.out_rescale;
635     const uint32_t query_start = blockIdx.x * kQueriesPerBlock;
636 
637     static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
638     if (thread_id() < kQueriesPerBlock) {
639       s_prime[thread_id()] = accum_t(0);
640       out_rescale[thread_id()] = accum_t(1.0);
641       m_prime[thread_id()] =
642           -cutlass::platform::numeric_limits<accum_t>::infinity();
643       mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
644     }
645     typename MM1::Mma::FragmentC accum_o;
646     accum_o.clear();
647 
648     auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
649       using OutputTileIterator = typename MM1::OutputTileIterator;
650       return OutputTileIterator(
651           typename OutputTileIterator::Params{(int32_t)p.o_strideM},
652           p.output_ptr,
653           typename OutputTileIterator::TensorCoord{
654               p.num_queries, p.head_dim_value},
655           thread_id(),
656           {0, col});
657     };
658 
659     auto createOutputAccumIter = [&](int col) ->
660         typename MM1::OutputTileIteratorAccum {
661           using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
662           return OutputTileIteratorAccum(
663               typename OutputTileIteratorAccum::Params{
664                   (int32_t)(p.head_dim_value * p.num_heads)},
665               p.output_accum_ptr,
666               typename OutputTileIteratorAccum::TensorCoord{
667                   p.num_queries, p.head_dim_value},
668               thread_id(),
669               {0, col});
670         };
671 
672     curandStatePhilox4_32_10_t curand_state_init;
673     if (kSupportsDropout && p.use_dropout) {
674       const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);
675       if (p.rng_engine_inputs.captured_) {
676         // See Note [Seed and Offset Device]
677         // When we are in cuda graph capture mode, the seed and offset are
678         // stored on device. We pass in int64_t* seed and int64_t* offset to
679         // act as scratch space for storing the RNG state during the forward
680         // pass and saving it for the backward pass.
681         auto [seed, offset] = seeds;
682         *p.seed = seed;
683         *p.extragraph_offset = offset;
684       }
685       // Each element of the attention matrix P with shape
686       // (batch_sz, n_heads, n_queries, n_keys) is associated with a single
687       // offset in the RNG sequence. We initialize the RNG state with an
688       // offset that starts at the beginning of the (n_queries, n_keys)
689       // matrix for this block's batch_id and head_id.
690       // Initializing the RNG state is very expensive, so we do it once per
691       // kernel rather than once per iteration. Each iteration takes a copy
692       // of the initialized RNG state and offsets it as needed.
693       curand_init(
694           std::get<0>(seeds),
695           0,
696           std::get<1>(seeds) + p.dropout_batch_head_rng_offset,
697           &curand_state_init);
698     }
699 
700     // Iterate through keys
701     for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
702          iter_key_start += kKeysPerBlock) {
703       if (p.window_size > 0) {
704         // don't compute anything if below attention band
705         if (iter_key_start + kKeysPerBlock <
706             int32_t(query_start + p.causal_diagonal_offset) - p.window_size) {
707           continue;
708         }
709       }
710       int32_t problem_size_0_m =
711           cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries);
712       int32_t problem_size_0_n = cutlass::fast_min(
713           int32_t(kKeysPerBlock), p.num_keys - iter_key_start);
714       int32_t const& problem_size_0_k = p.head_dim;
715       int32_t const& problem_size_1_n = p.head_dim_value;
716       int32_t const& problem_size_1_k = problem_size_0_n;
717 
718       auto prologueV = [&](int blockN) {
719         typename MM1::Mma::IteratorB iterator_V(
720             typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
721             const_cast<scalar_t*>(p.value_ptr + iter_key_start * p.v_strideM),
722             {problem_size_1_k, problem_size_1_n},
723             thread_id(),
724             cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
725         MM1::Mma::prologue(
726             shared_storage.after_mm0.mm1,
727             iterator_V,
728             thread_id(),
729             problem_size_1_k);
730       };
731 
732       __syncthreads(); // Need to have shared memory initialized, and `m_prime`
733                        // updated from end of prev iter
734       //
735       // MATMUL: Q.K_t
736       //
737       // Computes the block-matrix product of:
738       // (a) query[query_start:query_end, :]
739       // with
740       // (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
741       // and stores that into `shared_storage.si`
742       //
743 
744       // Compute threadblock location
745       cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0};
746 
747       cutlass::MatrixCoord tb_offset_A{
748           tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()};
749 
750       cutlass::MatrixCoord tb_offset_B{
751           tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN};
752 
753       // Construct iterators to A and B operands
754       typename MM0::IteratorA iterator_A(
755           typename MM0::IteratorA::Params(
756               typename MM0::MmaCore::LayoutA(p.q_strideM)),
757           const_cast<scalar_t*>(p.query_ptr),
758           {problem_size_0_m, problem_size_0_k},
759           thread_id(),
760           tb_offset_A);
761 
762       typename MM0::IteratorB iterator_B(
763           typename MM0::IteratorB::Params(
764               typename MM0::MmaCore::LayoutB(p.k_strideM)),
765           const_cast<scalar_t*>(p.key_ptr + iter_key_start * p.k_strideM),
766           {problem_size_0_k, problem_size_0_n},
767           thread_id(),
768           tb_offset_B);
769 
770       auto my_warp_id = warp_uniform(warp_id());
771       auto my_lane_id = lane_id();
772 
773       // Construct thread-scoped matrix multiply
774       typename MM0::Mma mma(
775           shared_storage.mm0, thread_id(), my_warp_id, my_lane_id);
776 
777       typename MM0::Mma::FragmentC accum;
778 
779       accum.clear();
780 
781       auto gemm_k_iterations =
782           (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;
783 
784       // Compute threadblock-scoped matrix multiply-add
785       mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
786       __syncthreads();
787 
788       if (kPreloadV) {
789         prologueV(0);
790       }
791 
792       typename MM0::Mma::Operator::IteratorC::TensorCoord
793           iteratorC_tile_offset = {
794               (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) +
795                   (my_warp_id % MM0::Mma::WarpCount::kM),
796               (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) +
797                   (my_warp_id / MM0::Mma::WarpCount::kM)};
798 
799       // multiply by scaling factor
800       if (kSupportsBias) {
801         accum =
802             cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
803       }
804 
805       // apply attention bias if applicable
806       if (kSupportsBias && p.attn_bias_ptr != nullptr) {
807         // load bias tile Bij into shared memory
808         typename MM0::BiasLoader::GmemTileIterator bias_iter(
809             {cutlass::layout::RowMajor(p.bias_strideM)},
810             // attn_bias_pointer points to matrix of size (n_queries, n_keys)
811             // for the relevant batch_id and head_id
812             const_cast<scalar_t*>(p.attn_bias_ptr + query_start * p.bias_strideM + iter_key_start),
813             {problem_size_0_m, problem_size_0_n},
814             thread_id());
815         cutlass::TensorRef<scalar_t, cutlass::layout::RowMajor> bias_tensor_ref(
816             shared_storage.after_mm0.bias.data(),
817             cutlass::layout::RowMajor(MM0::ThreadblockShape::kN));
818         typename MM0::BiasLoader::SmemTileIterator smem_tile_iter(
819             bias_tensor_ref, thread_id());
820         MM0::BiasLoader::load(bias_iter, smem_tile_iter);
821 
822         // Pij += Bij, Pij is in register fragment and Bij is in shared memory
823         auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
824             my_lane_id, my_warp_id, iteratorC_tile_offset);
825         MM0::AccumLambdaIterator::iterateRows(
826             lane_offset,
827             [&](int accum_m) {},
828             [&](int accum_m, int accum_n, int idx) {
829               if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) {
830                 accum[idx] += bias_tensor_ref.at({accum_m, accum_n});
831               }
832             },
833             [&](int accum_m) {});
834       }
835 
836       // Mask out the last columns if causal.
837       // This is only needed if the upper-right corner of the current
838       // query/key block intersects the mask. That corner is at
839       // y = query_start, x = min(iter_key_start + kKeysPerBlock, num_keys).
840       // The first masked element is x = y + offset -> query_start + offset,
841       // so there is an intersection (and we need to mask) if
842       // min(iter_key_start + kKeysPerBlock, num_keys) >= query_start + offset.
843       if (p.custom_mask_type &&
844           cutlass::fast_min(iter_key_start + kKeysPerBlock, p.num_keys) >=
845               (query_start + p.causal_diagonal_offset)) {
846         auto query_start = blockIdx.x * kQueriesPerBlock;
847         auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
848             my_lane_id, my_warp_id, iteratorC_tile_offset);
849         int32_t last_col;
850         MM0::AccumLambdaIterator::iterateRows(
851             lane_offset,
852             [&](int accum_m) {
853               // last absolute col is (last absolute query + offset)
854               // last local col is (last absolute query + offset -
855               // iter_key_start)
856               last_col = query_start + accum_m + p.causal_diagonal_offset -
857                   iter_key_start;
858             },
859             [&](int accum_m, int accum_n, int idx) {
860               if (accum_n > last_col) {
861                 accum[idx] =
862                     -cutlass::platform::numeric_limits<accum_t>::infinity();
863               }
864             },
865             [&](int accum_m) {});
866       }
867 
868       // Mask out the lower-left corner of the block if window_size > 0.
869       // This is only required if the current block intersects that corner.
870       // The block starts at x_lowerleft = iter_key_start and its bottom row
871       // is y = query_start + kQueriesPerBlock; the first non-masked value on
872       // that row is x_first = query_start + kQueriesPerBlock - window_size,
873       // so we mask if x_first > x_lowerleft.
874 
875       if (p.window_size > 0 &&
876           (query_start + p.causal_diagonal_offset +
877                cutlass::fast_min(
878                    int32_t(kQueriesPerBlock), int32_t(p.num_queries)) -
879                p.window_size >=
880            iter_key_start)) {
881         auto query_start = blockIdx.x * kQueriesPerBlock;
882         auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
883             my_lane_id, my_warp_id, iteratorC_tile_offset);
884         int32_t first_col;
885         const int32_t offset = query_start + p.causal_diagonal_offset -
886             p.window_size - iter_key_start;
887         MM0::AccumLambdaIterator::iterateRows(
888             lane_offset,
889             [&](int accum_m) { first_col = accum_m + offset; },
890             [&](int accum_m, int accum_n, int idx) {
891               if (accum_n <= first_col) {
892                 accum[idx] =
893                     -cutlass::platform::numeric_limits<accum_t>::infinity();
894               }
895             },
896             [&](int accum_m) {});
897         // print_warp_accum<MM0::AccumLambdaIterator>(accum, lane_offset, 12,
898         // 12);
899       }
900 
901       // Update `mi` from accum stored in registers
902       // Also does accum[i] <- exp(accum[i] - mi)
903       iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
904           accum_o,
905           accum,
906           mi,
907           m_prime,
908           s_prime,
909           out_rescale,
910           shared_storage.addition_storage,
911           my_lane_id,
912           thread_id(),
913           my_warp_id,
914           p.num_keys - iter_key_start,
915           iter_key_start == 0,
916           iteratorC_tile_offset,
917           kSupportsBias ? 1.0f : p.scale);
918 
919       // Output results to shared-memory
920       int warp_idx_mn_0 = my_warp_id %
921           (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
922       auto output_tile_coords = cutlass::MatrixCoord{
923           warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
924           warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};
925 
926       MM0::B2bGemm::accumToSmem(
927           shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords);
928 
929       __syncthreads();
930 
931       // apply dropout (if applicable) after we've written Pij to smem.
932       // dropout is applied by multiplying each element of Pij by:
933       // - 0 with probability dropout_p
934       // - 1 / (1 - dropout_p) with probability 1 - dropout_p
935       //
936       // for backward purposes we want to be able to map each element of the
937       // attention matrix to the same random uniform number as the one we used
938       // in forward, without needing to use the same iteration order or having
939       // to store the dropout matrix. It's possible to do this in registers but
940       // it ends up being very slow because each thread having noncontiguous
941       // strips of the Pij tile means we have to skip around a lot, and also
942       // have to generate a single random number at a time
943       if (kSupportsDropout && p.use_dropout) {
944         auto si = shared_storage.after_mm0.si.accum_ref();
945         // Each thread handles a contiguous sequence of elements from Sij,
946         // all coming from the same row. They must come from the same row
947         // because sampling random numbers from a contiguous random number
948         // sequence is much more efficient than jumping around, and the
949         // linear offset of each element of S (the global matrix) maps to an
950         // offset in a random number sequence. For S, the end of a row and
951         // the beginning of the next have adjacent offsets, but for Sij this
952         // is not necessarily the case.
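        // Concretely, the element at global position (query_start + thread_i,
        // iter_key_start + j) maps to Philox offset
        // (query_start + thread_i) * num_keys_absolute + (iter_key_start + j),
        // which is what the skipahead() call below jumps to.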
953         const int num_threads = blockDim.x * blockDim.y * blockDim.z;
954         const int threads_per_row =
955             cutlass::fast_min(num_threads / problem_size_0_m, problem_size_0_n);
956         const int elts_per_thread = cutlass::round_nearest(
957             cutlass::ceil_div(problem_size_0_n, threads_per_row), 4);
958 
959         const int thread_i = thread_id() / threads_per_row;
960         const int thread_start_j =
961             (thread_id() % threads_per_row) * elts_per_thread;
962 
963         if (thread_i < problem_size_0_m && thread_start_j < problem_size_0_n) {
964           curandStatePhilox4_32_10_t curand_state = curand_state_init;
965           skipahead(
966               static_cast<unsigned long long>(
967                   (query_start + thread_i) * p.num_keys_absolute +
968                   (iter_key_start + thread_start_j)),
969               &curand_state);
970           const float dropout_scale = 1.0 / (1.0 - p.dropout_prob);
971 
972           // apply dropout scaling to elements this thread is responsible for,
973           // in chunks of 4
974           for (int sij_start_col_idx = thread_start_j; sij_start_col_idx <
975                cutlass::fast_min(thread_start_j + elts_per_thread,
976                                  problem_size_0_n);
977                sij_start_col_idx += 4) {
978             const float4 rand_uniform_quad = curand_uniform4(&curand_state);
979 
980             CUTLASS_PRAGMA_UNROLL
981             for (int quad_idx = 0; quad_idx < 4; ++quad_idx) {
982               si.at({thread_i, sij_start_col_idx + quad_idx}) *=
983                   static_cast<scalar_t>(
984                       dropout_scale *
985                       ((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob));
986             }
987           }
988         }
989         __syncthreads(); // p.use_dropout should have same value kernel-wide
990       }
991 
992       //
993       // MATMUL: Attn . V
994       // Run the matmul `attn @ V` for a block of attn and V.
995       // `attn` is read from shared memory (in `shared_storage_si`)
996       // `V` is read from global memory (with iterator_B)
997       //
998 
999       const int64_t nBlockN = kSingleValueIteration
1000           ? 1
1001           : ceil_div(
1002                 (int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN));
1003       for (int blockN = 0; blockN < nBlockN; ++blockN) {
1004         int gemm_k_iterations =
1005             (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;
1006 
1007         // Compute threadblock-scoped matrix multiply-add and store it in accum
1008         // (in registers)
1009         if (!kPreloadV) {
1010           __syncthreads(); // we share shmem between mma and epilogue
1011         }
1012 
1013         typename MM1::Mma::IteratorB iterator_V(
1014             typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
1015             const_cast<scalar_t*>(p.value_ptr + iter_key_start * p.v_strideM),
1016             {problem_size_1_k, problem_size_1_n},
1017             thread_id(),
1018             cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
1019         typename MM1::Mma mma_pv(
1020             // operand A: Pij_dropped in shared memory
1021             shared_storage.after_mm0.si.accum_ref(),
1022             // operand B: shared memory staging area for Vj, which is loaded
1023             // from global memory
1024             shared_storage.after_mm0.mm1.operand_B_ref(),
1025             (int)thread_id(),
1026             (int)my_warp_id,
1027             (int)my_lane_id);
1028         mma_pv.set_prologue_done(kPreloadV);
1029         if (!kKeepOutputInRF) {
1030           accum_o.clear();
1031         }
1032         mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
1033         __syncthreads();
1034 
1035         if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) {
1036           prologueV(blockN + 1);
1037         }
1038 
1039         if (!kKeepOutputInRF) {
1040           int first_key = 0;
1041           if (p.window_size > 0) {
1042             first_key = (cutlass::fast_max(
1043                              int(query_start + p.causal_diagonal_offset) -
1044                                  p.window_size + 1,
1045                              0) /
1046                          kKeysPerBlock) *
1047                 kKeysPerBlock;
1048           }
1049 
1050           // int first_key_block = 0;
1051           // MM1::Mma::drain_cp_asyncs(); # TODO figure out if this is needed for correctness
1052           DISPATCH_BOOL(
1053               iter_key_start == first_key, kIsFirst, ([&] {
1054                 DISPATCH_BOOL(
1055                     (iter_key_start + kKeysPerBlock) >= p.num_keys,
1056                     kIsLast,
1057                     ([&] {
1058                       using DefaultEpilogue = typename MM1::DefaultEpilogue;
1059                       using DefaultOp =
1060                           typename MM1::DefaultConfig::EpilogueOutputOp;
1061                       using ElementCompute = typename DefaultOp::ElementCompute;
1062                       using EpilogueOutputOp = typename cutlass::epilogue::
1063                           thread::MemoryEfficientAttentionNormalize<
1064                               typename cutlass::platform::conditional<
1065                                   kIsLast,
1066                                   output_t,
1067                                   output_accum_t>::type,
1068                               output_accum_t,
1069                               DefaultOp::kCount,
1070                               typename DefaultOp::ElementAccumulator,
1071                               ElementCompute,
1072                               kIsFirst,
1073                               kIsLast,
1074                               cutlass::Array<ElementCompute, kQueriesPerBlock>>;
1075                       using Epilogue = typename cutlass::epilogue::threadblock::
1076                           EpiloguePipelined<
1077                               typename DefaultEpilogue::Shape,
1078                               typename MM1::Mma::Operator,
1079                               DefaultEpilogue::kPartitionsK,
1080                               typename cutlass::platform::conditional<
1081                                   kIsLast,
1082                                   typename MM1::OutputTileIterator,
1083                                   typename MM1::OutputTileIteratorAccum>::type,
1084                               typename DefaultEpilogue::
1085                                   AccumulatorFragmentIterator,
1086                               typename DefaultEpilogue::WarpTileIterator,
1087                               typename DefaultEpilogue::SharedLoadIterator,
1088                               EpilogueOutputOp,
1089                               typename DefaultEpilogue::Padding,
1090                               DefaultEpilogue::kFragmentsPerIteration,
1091                               true, // IterationsUnroll
1092                               typename MM1::OutputTileIteratorAccum // Read
1093                                                                     // iterator
1094                               >;
1095 
1096                       int col = blockN * MM1::Mma::Shape::kN;
1097                       auto source_iter = createOutputAccumIter(col);
1098                       auto dest_iter = call_conditional<
1099                           kIsLast,
1100                           decltype(createOutputIter),
1101                           decltype(createOutputAccumIter)>::
1102                           apply(createOutputIter, createOutputAccumIter, col);
1103                       EpilogueOutputOp rescale(s_prime, out_rescale);
1104                       Epilogue epilogue(
1105                           shared_storage.epilogue_shared_storage(),
1106                           thread_id(),
1107                           my_warp_id,
1108                           my_lane_id);
1109                       epilogue(rescale, dest_iter, accum_o, source_iter);
1110                     }));
1111               }));
1112           if (!kSingleValueIteration) {
1113             __syncthreads();
1114           }
1115         }
1116       }
1117       __syncthreads(); // we modify `m_prime` after
1118     }
1119 
1120     if (kKeepOutputInRF) {
1121       constexpr bool kIsFirst = true;
1122       constexpr bool kIsLast = true;
1123       using DefaultEpilogue = typename MM1::DefaultEpilogue;
1124       using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
1125       using ElementCompute = typename DefaultOp::ElementCompute;
1126       using EpilogueOutputOp =
1127           typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
1128               output_t, // output
1129               output_accum_t, // source
1130               DefaultOp::kCount,
1131               typename DefaultOp::ElementAccumulator, // accum
1132               output_accum_t, // compute
1133               kIsFirst,
1134               kIsLast,
1135               cutlass::Array<ElementCompute, kQueriesPerBlock>>;
1136       using Epilogue =
1137           typename cutlass::epilogue::threadblock::EpiloguePipelined<
1138               typename DefaultEpilogue::Shape,
1139               typename MM1::Mma::Operator,
1140               DefaultEpilogue::kPartitionsK,
1141               typename MM1::OutputTileIterator, // destination
1142               typename DefaultEpilogue::AccumulatorFragmentIterator,
1143               typename DefaultEpilogue::WarpTileIterator,
1144               typename DefaultEpilogue::SharedLoadIterator,
1145               EpilogueOutputOp,
1146               typename DefaultEpilogue::Padding,
1147               DefaultEpilogue::kFragmentsPerIteration,
1148               true, // IterationsUnroll
1149               typename MM1::OutputTileIteratorAccum // source tile
1150               >;
1151       auto dest_iter = createOutputIter(0);
1152       EpilogueOutputOp rescale(s_prime, out_rescale);
1153       Epilogue epilogue(
1154           shared_storage.epilogue_shared_storage(),
1155           thread_id(),
1156           warp_id(),
1157           lane_id());
1158       epilogue(rescale, dest_iter, accum_o);
1159     }
1160 
1161     // 7. Calculate logsumexp
1162     // To make the backward easier, we pad logsumexp with `inf`
1163     // This avoids a few bound checks, and is not more expensive during fwd.
1164     static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
1165     if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
1166       auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
1167       constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
1168       if (thread_id() < p.num_queries) {
1169         // Fully masked-out rows have s_prime == 0 (and mi == -inf); we set
1170         // s_prime to 1 (and mi to 0) before calling log so that the result is 0.
1171         s_prime[thread_id()] = (s_prime[thread_id()] == 0) ? 1: s_prime[thread_id()];
1172         mi[thread_id()] = (mi[thread_id()] == -cutlass::platform::numeric_limits<accum_t>::infinity()) ? 0: mi[thread_id()];
1173         p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) +
1174             cutlass::fast_log(accum_t(s_prime[thread_id()]));
1175       } else if (thread_id() < lse_dim) {
1176         p.logsumexp_ptr[thread_id()] =
1177             cutlass::platform::numeric_limits<accum_t>::infinity();
1178       }
1179     }
1180   }
1181 
1182   template <typename WarpIteratorC>
1183   CUTLASS_DEVICE static void iterative_softmax(
1184       typename WarpIteratorC::Fragment& frag_o, // output so far
1185       typename WarpIteratorC::Fragment& frag,
1186       cutlass::Array<accum_t, kQueriesPerBlock>& mi,
1187       cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
1188       cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
1189       cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale,
1190       cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
1191           addition_storage,
1192       int8_t lane_id,
1193       int8_t thread_id,
1194       int8_t warp_id,
1195       int max_col,
1196       bool is_first,
1197       typename WarpIteratorC::TensorCoord const& tile_offset,
1198       float scaling) {
1199     /* Iterates on the accumulator and corresponding position on result matrix
1200 
1201     (1) Update `mi[r]` to the max value of the row `r`
1202     (2) In a second iteration do the following:
1203         (a) accum   <- exp(accum - mi)
1204         (b) m_prime <- exp(m_prime - mi)
1205         (c) s_prime <- s_prime * m_prime + sum(accum)
1206 
1207     All of this is done on registers, before we store all of this
1208     on shared memory for the next matmul with Value.
1209     */
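    // In base-2 form (as implemented below with exp2f), with `frag` pre-scaled
    // by `scaling * log2(e)`:
    //   mi[r]      <- max(mi[r], max_n frag[r, n])
    //   rescale[r]  = 2^(m_prime[r] - mi[r])
    //   frag[r, n] <- 2^(frag[r, n] - mi[r])
    //   s_prime[r] <- s_prime[r] * rescale[r] + sum_n frag[r, n]
    // The logsumexp written out in attention_kernel() is then recovered as
    // mi[r] / log2(e) + log(s_prime[r]).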
1210     using Fragment = typename WarpIteratorC::Fragment;
1211     using LambdaIterator = typename DefaultMmaAccumLambdaIterator<
1212         WarpIteratorC,
1213         accum_t,
1214         kWarpSize>::Iterator;
1215     // Convert to `accum_t` (rather than double)
1216     constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
1217 
1218     static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
1219     static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;
1220 
1221     frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
1222 
1223     auto lane_offset =
1224         LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
1225 
1226     // First update `mi` to the max per-row
1227     {
1228       accum_t max;
1229       LambdaIterator::iterateRows(
1230           lane_offset,
1231           [&](int accum_m) {
1232             max = -cutlass::platform::numeric_limits<accum_t>::infinity();
1233           },
1234           [&](int accum_m, int accum_n, int idx) {
1235             if (accum_n < max_col) {
1236               max = cutlass::fast_max(max, frag[idx]);
1237             }
1238           },
1239           [&](int accum_m) {
1240             // Having 4x atomicMax seems faster than reduce within warp
1241             // first...
1242             atomicMaxFloat(&mi[accum_m], max);
1243           });
1244     }
1245 
1246     // Make sure we all share the update values for `mi`
1247     __syncthreads();
1248 
1249     // Doing this `exp` is quite expensive. Let's
1250     // split it across the warps
1251     bool restore_mi_to_minus_inf = false;
1252     if (lane_id < kLinesPerWarp) {
1253       int id = warp_id * kLinesPerWarp + lane_id;
1254       auto m_prime_id = m_prime[id];
1255       auto mi_id = mi[id];
1256       bool changed = m_prime_id < mi_id; // `false` if both are -inf
1257       if (changed) {
1258         auto m_prime_exp = exp2f(m_prime_id - mi_id);
1259         out_rescale[id] = m_prime_exp;
1260         s_prime[id] *= m_prime_exp;
1261       } else {
1262         // Only when bias is enabled is it possible that all the first
1263         // values of attention are masked to `-inf`. In that case we want to
1264         // avoid `nan = exp2f(-inf - (-inf))`, so we temporarily set `mi` to 0.
1265         if (kSupportsBias &&
1266             mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
1267           restore_mi_to_minus_inf = true;
1268           mi[id] = 0.0f;
1269         }
1270         out_rescale[id] = 1.0f;
1271       }
1272     }
1273     __syncthreads(); // Update output fragments
1274     if (kKeepOutputInRF && !is_first) {
1275       accum_t line_rescale;
1276       LambdaIterator::iterateRows(
1277           lane_offset,
1278           [&](int accum_m) { line_rescale = out_rescale[accum_m]; },
1279           [&](int accum_m, int accum_n, int idx) {
1280             frag_o[idx] = frag_o[idx] * line_rescale;
1281           },
1282           [&](int accum_m) {});
1283     }
1284     // Update accum_m, accum_n, ...
1285     {
1286       accum_t mi_row, total_row;
1287       LambdaIterator::iterateRows(
1288           lane_offset,
1289           [&](int accum_m) { mi_row = mi[accum_m]; },
1290           [&](int accum_m, int accum_n, int idx) {
1291             frag[idx] =
1292                 (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
1293           },
1294           [&](int accum_m) {});
1295       LambdaIterator::iterateRows(
1296           lane_offset,
1297           [&](int accum_m) { total_row = 0.0; },
1298           [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
1299           [&](int accum_m) {
1300             if (LambdaIterator::reduceSameRow(
1301                     lane_id, total_row, [](accum_t a, accum_t b) {
1302                       return a + b;
1303                     })) {
1304               // NOTE: we could atomically add `total_row` to `s_prime`, but
1305               // it's faster (and deterministic) to avoid atomics here
1306               addition_storage
1307                   [accum_m + kQueriesPerBlock * tile_offset.column()] =
1308                       total_row;
1309             }
1310           });
1311     }
1312     __syncthreads();
1313     if (lane_id < kLinesPerWarp) {
1314       int id = warp_id * kLinesPerWarp + lane_id;
1315       accum_t total_row = s_prime[id];
1316       if (restore_mi_to_minus_inf) {
1317         // Restore `mi`, see above when we set `restore_mi_to_minus_inf=true`
1318         mi[id] = -cutlass::platform::numeric_limits<accum_t>::infinity();
1319       } else {
1320         m_prime[id] = mi[id];
1321       }
1322       CUTLASS_PRAGMA_UNROLL
1323       for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
1324         total_row += addition_storage[id + kQueriesPerBlock * i];
1325       }
1326       s_prime[id] = total_row;
1327     }
1328   }
1329 
1330   static CUTLASS_DEVICE int8_t lane_id() {
1331     return threadIdx.x;
1332   }
1333   static CUTLASS_DEVICE int8_t warp_id() {
1334     return threadIdx.y;
1335   }
1336   static CUTLASS_DEVICE int16_t thread_id() {
1337     return threadIdx.x + threadIdx.y * blockDim.x;
1338   }
1339 };
1340 
1341 template <typename AK>
1342 __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
1343     attention_kernel_batched_impl(typename AK::Params p) {
1344   if (!p.advance_to_block()) {
1345     return;
1346   }
1347   AK::attention_kernel(p);
1348 }
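
// Illustrative host-side launch sketch (not part of this header; shown as a
// comment). The template arguments, the Params initialization, and `stream`
// are assumptions for the example; only AttentionKernel, Params,
// check_supported, getBlocksGrid, getThreadsGrid, SharedStorage and
// attention_kernel_batched_impl come from this file:
//
//   using AK = AttentionKernel<
//       cutlass::half_t, cutlass::arch::Sm80,
//       /*isAligned=*/true, /*kQueriesPerBlock=*/64, /*kKeysPerBlock=*/64,
//       /*kMaxK=*/64>;
//   typename AK::Params p;                // fill pointers, strides, dims, scale
//   AK::check_supported(p);               // TORCH_CHECKs alignment etc.
//   int smem_bytes = int(sizeof(typename AK::SharedStorage));
//   if (smem_bytes > 48 * 1024) {         // opt in to >48KB dynamic smem
//     cudaFuncSetAttribute(
//         attention_kernel_batched_impl<AK>,
//         cudaFuncAttributeMaxDynamicSharedMemorySize,
//         smem_bytes);
//   }
//   attention_kernel_batched_impl<AK>
//       <<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);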
1349 
1350 template <typename AK>
1351 __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
1352     attention_kernel_batched(typename AK::Params params);
1353 
1354 } // namespace PyTorchMemEffAttention
1355