/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cuda.h>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack

namespace pytorch_flash {

constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Qkv_params {
    using index_t = int64_t;
    // The QKV matrices.
    void *__restrict__ q_ptr;
    void *__restrict__ k_ptr;
    void *__restrict__ v_ptr;

    // The stride between rows of the Q, K and V matrices.
    index_t q_batch_stride;
    index_t k_batch_stride;
    index_t v_batch_stride;
    index_t q_row_stride;
    index_t k_row_stride;
    index_t v_row_stride;
    index_t q_head_stride;
    index_t k_head_stride;
    index_t v_head_stride;

    // The number of heads.
    int h, h_k;
    // In the case of multi-query and grouped-query attention (MQA/GQA), the number of
    // K/V heads (h_k) can be different from the number of query heads (h).
    int h_h_k_ratio; // Precomputed h / h_k.
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_fwd_params : public Qkv_params {

    // The O matrix (output).
    void * __restrict__ o_ptr;
    void * __restrict__ oaccum_ptr;

    // The stride between rows of O.
    index_t o_batch_stride;
    index_t o_row_stride;
    index_t o_head_stride;

    // The pointer to the P matrix.
    void * __restrict__ p_ptr;

    // The pointer to the softmax log-sum-exp (LSE).
    void * __restrict__ softmax_lse_ptr;
    void * __restrict__ softmax_lseaccum_ptr;

    // The dimensions.
    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;

    // The scaling factors for the kernel.
    float scale_softmax;
    float scale_softmax_log2;

    // Array of length b + 1 holding the starting offset of each sequence.
    int * __restrict__ cu_seqlens_q;
    int * __restrict__ cu_seqlens_k;

    // If provided, the actual length of each K sequence.
    int * __restrict__ seqused_k;

    int *__restrict__ blockmask;

    // The K_new and V_new matrices.
    void * __restrict__ knew_ptr;
    void * __restrict__ vnew_ptr;

    // The stride between rows of the K_new and V_new matrices.
    index_t knew_batch_stride;
    index_t vnew_batch_stride;
    index_t knew_row_stride;
    index_t vnew_row_stride;
    index_t knew_head_stride;
    index_t vnew_head_stride;

    // The cos and sin matrices for rotary embedding.
    void * __restrict__ rotary_cos_ptr;
    void * __restrict__ rotary_sin_ptr;

    // The indices to index into the KV cache.
    int * __restrict__ cache_batch_idx;

    // Paged KV cache
    int * __restrict__ block_table;
    index_t block_table_batch_stride;
    int page_block_size;

    // The dropout probability (probability of keeping an activation).
    float p_dropout;
    // uint32_t p_dropout_in_uint;
    // uint16_t p_dropout_in_uint16_t;
    uint8_t p_dropout_in_uint8_t;

    // Scale factor of 1 / (1 - p_dropout).
    float rp_dropout;
    float scale_softmax_rp_dropout;

    // Local window size
    int window_size_left, window_size_right;

    // Random state.
    at::PhiloxCudaState philox_args;
    int64_t * extragraph_offset;
    int64_t * seed;

    bool is_bf16;
    bool is_causal;

    // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
    // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
    bool is_seqlens_k_cumulative;

    bool is_rotary_interleaved;

    int num_splits; // For split-KV version

    void * __restrict__ alibi_slopes_ptr;
    index_t alibi_slopes_batch_stride;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_bwd_params : public Flash_fwd_params {

    // The dO and dQKV matrices.
    void *__restrict__ do_ptr;
    void *__restrict__ dq_ptr;
    void *__restrict__ dk_ptr;
    void *__restrict__ dv_ptr;

    // To accumulate dQ
    void *__restrict__ dq_accum_ptr;
    void *__restrict__ dk_accum_ptr;
    void *__restrict__ dv_accum_ptr;

    // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q dimension
    // void *__restrict__ dk_accum_ptr;
    // void *__restrict__ dv_accum_ptr;

    // The stride between rows of the dO, dQ, dK and dV matrices.
    // TD [2022-04-16]: We're using 32-bit indexing to save registers.
    // The code probably won't work for arrays larger than 2GB.
    index_t do_batch_stride;
    index_t do_row_stride;
    index_t do_head_stride;
    index_t dq_batch_stride;
    index_t dk_batch_stride;
    index_t dv_batch_stride;
    index_t dq_row_stride;
    index_t dk_row_stride;
    index_t dv_row_stride;
    index_t dq_head_stride;
    index_t dk_head_stride;
    index_t dv_head_stride;

    // The pointer to the softmax d sum.
    void *__restrict__ dsoftmax_sum;

    bool deterministic;
    index_t dq_accum_split_stride;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);

} // namespace pytorch_flash
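////////////////////////////////////////////////////////////////////////////////////////////////////

// Illustrative sketch only (not part of the dispatch API above): one plausible way a caller
// could fill in the core Qkv_params / Flash_fwd_params fields for contiguous
// [batch, seqlen, nheads, headdim] tensors. The helper name fill_example_fwd_params, the
// element-unit stride convention, and the base-2 rescaling of the softmax scale are
// assumptions made for illustration; the authoritative setup lives in the calling code.

#include <cstdint>

namespace pytorch_flash {

inline void fill_example_fwd_params(Flash_fwd_params &params,
                                    void *q, void *k, void *v, void *o,
                                    int b, int seqlen_q, int seqlen_k,
                                    int h, int h_k, int d,
                                    float softmax_scale, bool is_causal) {
    // Assumes the caller value-initialized `params` (e.g. `Flash_fwd_params params{};`),
    // so every field not set here (dropout, rounded sizes, varlen/paged-KV pointers, ...)
    // starts out as zero / nullptr.

    // Pointers to the externally allocated Q, K, V and O tensors.
    params.q_ptr = q;
    params.k_ptr = k;
    params.v_ptr = v;
    params.o_ptr = o;

    // Strides for a contiguous [b, seqlen, h, d] layout, counted in elements (assumption).
    params.q_row_stride   = static_cast<int64_t>(h) * d;
    params.q_head_stride  = d;
    params.q_batch_stride = static_cast<int64_t>(seqlen_q) * h * d;
    params.k_row_stride   = static_cast<int64_t>(h_k) * d;
    params.k_head_stride  = d;
    params.k_batch_stride = static_cast<int64_t>(seqlen_k) * h_k * d;
    params.v_row_stride   = params.k_row_stride;
    params.v_head_stride  = params.k_head_stride;
    params.v_batch_stride = params.k_batch_stride;
    params.o_row_stride   = params.q_row_stride;
    params.o_head_stride  = params.q_head_stride;
    params.o_batch_stride = params.q_batch_stride;

    // Problem sizes; for MQA/GQA, h is expected to be a multiple of h_k.
    params.b = b;
    params.seqlen_q = seqlen_q;
    params.seqlen_k = seqlen_k;
    params.d = d;
    params.h = h;
    params.h_k = h_k;
    params.h_h_k_ratio = h / h_k;

    // Softmax scale, plus a base-2 variant (scale * log2(e)) to match the
    // scale_softmax_log2 field name (assumption).
    params.scale_softmax = softmax_scale;
    params.scale_softmax_log2 = softmax_scale * 1.4426950408889634f; // log2(e)

    params.is_causal = is_causal;
}

// Hypothetical usage, e.g. 2 batches of 1024 tokens, 16 heads of dimension 64:
//   Flash_fwd_params p{};
//   fill_example_fwd_params(p, q, k, v, o, /*b=*/2, /*seqlen_q=*/1024, /*seqlen_k=*/1024,
//                           /*h=*/16, /*h_k=*/16, /*d=*/64, /*softmax_scale=*/0.125f,
//                           /*is_causal=*/true);

} // namespace pytorch_flash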