#pragma once
#include <cstddef>
#include <optional>
#include <tuple>

#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>

namespace pytorch_flash {

TORCH_API
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_fwd(const at::Tensor &q,                      // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size
        std::optional<at::Tensor> &out_,          // batch_size x seqlen_q x num_heads x head_size
        std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,                    // dropout probability
        const float softmax_scale,
        bool is_causal,
        int window_size_left,                     // sliding-window size to the left; -1 means unrestricted
        int window_size_right,                    // sliding-window size to the right; -1 means unrestricted
        const bool return_softmax,
        std::optional<at::Generator> gen_);       // RNG state used for dropout

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_varlen_fwd(const at::Tensor &q,               // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b-1} s_i
               const at::Tensor &k,               // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b-1} s_i
               const at::Tensor &v,               // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b-1} s_i
               std::optional<at::Tensor> &out_,   // total_q x num_heads x head_size
               const at::Tensor &cu_seqlens_q,    // b+1
               const at::Tensor &cu_seqlens_k,    // b+1
               std::optional<at::Tensor> &seqused_k,    // b. If given, only this many elements of each batch element's keys are used.
               std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
               std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,             // dropout probability
               const float softmax_scale,
               const bool zero_tensors,
               bool is_causal,
               int window_size_left,              // sliding-window size to the left; -1 means unrestricted
               int window_size_right,             // sliding-window size to the right; -1 means unrestricted
               const bool return_softmax,
               std::optional<at::Generator> gen_); // RNG state used for dropout

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num_heads x head_size_og
        const at::Tensor &q,                      // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out,                    // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &softmax_lse,            // b x h x seqlen_q
        std::optional<at::Tensor> &dq_,           // batch_size x seqlen_q x num_heads x head_size
        std::optional<at::Tensor> &dk_,           // batch_size x seqlen_k x num_heads_k x head_size
        std::optional<at::Tensor> &dv_,           // batch_size x seqlen_k x num_heads_k x head_size
        std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,                    // dropout probability
        const float softmax_scale,
        const bool is_causal,
        int window_size_left,                     // sliding-window size to the left; -1 means unrestricted
        int window_size_right,                    // sliding-window size to the right; -1 means unrestricted
        const bool deterministic,
        const at::Tensor philox_seed,             // RNG seed saved by the forward pass
        const at::Tensor philox_offset);          // RNG offset saved by the forward pass

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_varlen_bwd(const at::Tensor &dout,            // total_q x num_heads x head_size
               const at::Tensor &q,               // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b-1} s_i
               const at::Tensor &k,               // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b-1} s_i
               const at::Tensor &v,               // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b-1} s_i
               const at::Tensor &out,             // total_q x num_heads x head_size
               const at::Tensor &softmax_lse,     // b x h x s, softmax logsumexp
               std::optional<at::Tensor> &dq_,    // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b-1} s_i
               std::optional<at::Tensor> &dk_,    // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b-1} s_i
               std::optional<at::Tensor> &dv_,    // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b-1} s_i
               const at::Tensor &cu_seqlens_q,    // b+1
               const at::Tensor &cu_seqlens_k,    // b+1
               std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               const int max_seqlen_q,
               const int max_seqlen_k,            // max sequence length, used to choose the kernel
               const float p_dropout,             // dropout probability
               const float softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
               int window_size_left,              // sliding-window size to the left; -1 means unrestricted
               int window_size_right,             // sliding-window size to the right; -1 means unrestricted
               const bool deterministic,
               const at::Tensor philox_seed,      // RNG seed saved by the forward pass
               const at::Tensor philox_offset);   // RNG offset saved by the forward pass

} // namespace pytorch_flash
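
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only; not part of this header). A minimal,
// hedged example of calling mha_fwd on CUDA half-precision tensors. Shapes
// follow the parameter comments above; the function name
// flash_forward_example and all concrete sizes are hypothetical, and the
// assumption that the first tuple element is the attention output should be
// verified against the kernel implementation.
//
//   #include <cmath>
//   #include <ATen/ATen.h>
//
//   at::Tensor flash_forward_example() {
//     const int64_t batch_size = 2, seqlen = 128, num_heads = 8, head_size = 64;
//     auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
//     at::Tensor q = at::randn({batch_size, seqlen, num_heads, head_size}, opts);
//     at::Tensor k = at::randn({batch_size, seqlen, num_heads, head_size}, opts);
//     at::Tensor v = at::randn({batch_size, seqlen, num_heads, head_size}, opts);
//     std::optional<at::Tensor> out_;          // empty: let the kernel allocate
//     std::optional<at::Tensor> alibi_slopes_; // empty: no ALiBi bias
//     auto result = pytorch_flash::mha_fwd(
//         q, k, v, out_, alibi_slopes_,
//         /*p_dropout=*/0.0f,
//         /*softmax_scale=*/1.0f / std::sqrt(static_cast<float>(head_size)),
//         /*is_causal=*/true,
//         /*window_size_left=*/-1,  // -1: no sliding-window restriction
//         /*window_size_right=*/-1,
//         /*return_softmax=*/false,
//         /*gen_=*/std::nullopt);
//     return std::get<0>(result);   // attention output
//   }
// ---------------------------------------------------------------------------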
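
// ---------------------------------------------------------------------------
// Companion sketch for the varlen entry points (again illustrative only):
// how cu_seqlens_q / cu_seqlens_k are laid out. For b sequences of lengths
// s_0..s_{b-1}, cu_seqlens is an int32 exclusive prefix sum of length b+1,
// so sequence i occupies rows [cu_seqlens[i], cu_seqlens[i+1]) of the packed
// total_q x num_heads x head_size tensor. The literal lengths below are
// hypothetical.
//
//   // Two sequences of lengths 100 and 28 packed into total_q = 128 rows:
//   at::Tensor cu_seqlens_q = at::tensor({0, 100, 128},
//       at::TensorOptions().dtype(at::kInt).device(at::kCUDA));
// ---------------------------------------------------------------------------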