/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cuda.h>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
namespace pytorch_flash {
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Qkv_params {
    using index_t = int64_t;
    // The QKV matrices.
    void *__restrict__ q_ptr;
    void *__restrict__ k_ptr;
    void *__restrict__ v_ptr;

    // The batch, row and head strides of the Q, K and V matrices.
    index_t q_batch_stride;
    index_t k_batch_stride;
    index_t v_batch_stride;
    index_t q_row_stride;
    index_t k_row_stride;
    index_t v_row_stride;
    index_t q_head_stride;
    index_t k_head_stride;
    index_t v_head_stride;
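
    // A sketch of the addressing these strides imply (an assumption for
    // illustration; the kernels are the source of truth). With strides expressed
    // in elements of the data type, element (batch b, row s, head h, dim i) of Q
    // would live at
    //   static_cast<T *>(q_ptr) + b * q_batch_stride + s * q_row_stride + h * q_head_stride + i;
    // with K and V following the same pattern through their own strides.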

    // The number of heads.
    int h, h_k;
    // In the case of multi-query and grouped-query attention (MQA/GQA), the number
    // of K/V heads (h_k) can be different from the number of query heads (h).
    int h_h_k_ratio; // precomputed h / h_k
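    // Illustrative example (an assumption, not part of the API): with h = 32 query
    // heads and h_k = 8 K/V heads (GQA), h_h_k_ratio = 4, and a query head would
    // read its K/V head as
    //   int kv_head = query_head / h_h_k_ratio;  // e.g. query head 13 -> K/V head 3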
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_fwd_params : public Qkv_params {

    // The O matrix (output).
    void * __restrict__ o_ptr;
    void * __restrict__ oaccum_ptr;

    // The batch, row and head strides of O.
    index_t o_batch_stride;
    index_t o_row_stride;
    index_t o_head_stride;

    // The pointer to the P matrix (attention probabilities).
    void * __restrict__ p_ptr;

    // The pointer to the softmax log-sum-exp.
    void * __restrict__ softmax_lse_ptr;
    void * __restrict__ softmax_lseaccum_ptr;

    // The dimensions.
    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;

    // The scaling factors for the kernel.
    float scale_softmax;
    float scale_softmax_log2;
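
    // A sketch of how these are typically filled in by the host-side setup code
    // (an assumption; see the API layer for the source of truth):
    //   scale_softmax      = softmax_scale;            // usually 1 / sqrt(headdim)
    //   scale_softmax_log2 = softmax_scale * M_LOG2E;  // lets kernels use exp2f instead of expf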

    // Arrays of length b + 1 holding the starting offset of each sequence.
    int * __restrict__ cu_seqlens_q;
    int * __restrict__ cu_seqlens_k;
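
    // Example layout: for a varlen batch with sequence lengths {3, 5, 2},
    // cu_seqlens = {0, 3, 8, 10}; sequence i has length
    // cu_seqlens[i + 1] - cu_seqlens[i] and its rows start at cu_seqlens[i].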

    // If provided, the actual length of each k sequence.
    int * __restrict__ seqused_k;

    int *__restrict__ blockmask;

    // The K_new and V_new matrices.
    void * __restrict__ knew_ptr;
    void * __restrict__ vnew_ptr;

    // The batch, row and head strides of the K_new and V_new matrices.
    index_t knew_batch_stride;
    index_t vnew_batch_stride;
    index_t knew_row_stride;
    index_t vnew_row_stride;
    index_t knew_head_stride;
    index_t vnew_head_stride;

    // The cos and sin matrices for rotary embedding.
    void * __restrict__ rotary_cos_ptr;
    void * __restrict__ rotary_sin_ptr;

    // The indices to index into the KV cache.
    int * __restrict__ cache_batch_idx;

    // Paged KV cache
    int * __restrict__ block_table;
    index_t block_table_batch_stride;
    int page_block_size;
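
    // A sketch of the paged addressing (an assumption for illustration): for batch
    // index b, K/V row `row` would live in physical block
    //   block_table[b * block_table_batch_stride + row / page_block_size]
    // at offset row % page_block_size within that block.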

    // The dropout probability (probability of keeping an activation).
    float p_dropout;
    // uint32_t p_dropout_in_uint;
    // uint16_t p_dropout_in_uint16_t;
    uint8_t p_dropout_in_uint8_t;

    // Scale factor of 1 / (1 - p_dropout).
    float rp_dropout;
    float scale_softmax_rp_dropout;
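
    // A sketch of how these are typically derived on the host from the user-facing
    // drop probability p (an assumption; the setup code is the source of truth):
    //   p_dropout                = 1.f - p;              // keep probability
    //   rp_dropout               = 1.f / p_dropout;      // = 1 / (1 - p)
    //   scale_softmax_rp_dropout = rp_dropout * scale_softmax;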

    // Local window size
    int window_size_left, window_size_right;
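    // Illustrative semantics (an assumption): query position i attends to key
    // positions in [i - window_size_left, i + window_size_right] after aligning the
    // ends of the two sequences; a negative value means unbounded on that side.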

    // Random state.
    at::PhiloxCudaState philox_args;
    int64_t * extragraph_offset;
    int64_t * seed;

    bool is_bf16;
    bool is_causal;

    // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
    // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
    bool is_seqlens_k_cumulative;

    bool is_rotary_interleaved;

    int num_splits;  // For split-KV version
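    // A sketch of the split-KV scheme (an assumption): the key/value sequence is
    // processed in num_splits chunks whose partial results land in oaccum_ptr and
    // softmax_lseaccum_ptr and are then combined; 1 disables splitting and 0
    // typically lets the host pick the split count heuristically.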

    void * __restrict__ alibi_slopes_ptr;
    index_t alibi_slopes_batch_stride;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_bwd_params : public Flash_fwd_params {

    // The dO and dQKV matrices.
    void *__restrict__ do_ptr;
    void *__restrict__ dq_ptr;
    void *__restrict__ dk_ptr;
    void *__restrict__ dv_ptr;

    // To accumulate dQ
    void *__restrict__ dq_accum_ptr;
    void *__restrict__ dk_accum_ptr;
    void *__restrict__ dv_accum_ptr;

    // To accumulate dK and dV in case we're splitting the bwd along the seqlen_q
    // dimension:
    // void *__restrict__ dk_accum_ptr;
    // void *__restrict__ dv_accum_ptr;

    // The batch, row and head strides of the dO, dQ, dK and dV matrices.
    // TD [2022-04-16]: We're using 32-bit indexing to save registers.
    // The code probably won't work for arrays larger than 2GB.
    index_t do_batch_stride;
    index_t do_row_stride;
    index_t do_head_stride;
    index_t dq_batch_stride;
    index_t dk_batch_stride;
    index_t dv_batch_stride;
    index_t dq_row_stride;
    index_t dk_row_stride;
    index_t dv_row_stride;
    index_t dq_head_stride;
    index_t dk_head_stride;
    index_t dv_head_stride;

    // The pointer to the softmax d sum.
    void *__restrict__ dsoftmax_sum;

    bool deterministic;
    index_t dq_accum_split_stride;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
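
// A hedged sketch of how a caller might select an instantiation (the helper below
// is hypothetical; the real dispatch lives in the host-side API code):
//
//   template <typename T>
//   void run_mha_fwd_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
//       // Pick the smallest instantiated head dimension that fits params.d.
//       if (params.d <= 64) {
//           run_mha_fwd_<T, 64>(params, stream);
//       } else if (params.d <= 128) {
//           run_mha_fwd_<T, 128>(params, stream);
//       }  // ... and so on for the remaining instantiated head dimensions.
//   }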

} // namespace pytorch_flash