#pragma once
#include <cstddef>

#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>

namespace pytorch_flash {

TORCH_API
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_fwd(const at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,         // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,         // batch_size x seqlen_k x num_heads_k x head_size
        std::optional<at::Tensor> &out_,          // batch_size x seqlen_q x num_heads x head_size
        std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,
        const float softmax_scale,
        bool is_causal,
        int window_size_left,
        int window_size_right,
        const bool return_softmax,
        std::optional<at::Generator> gen_);
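
// Example (a minimal usage sketch, not part of the API surface: the shapes,
// dtypes, and the -1 "no sliding window" convention below are assumptions
// carried over from the parameter comments above):
//
//   auto opts = at::dtype(at::kHalf).device(at::kCUDA);
//   at::Tensor q = at::randn({2, 128, 8, 64}, opts);
//   at::Tensor k = at::randn({2, 256, 8, 64}, opts);
//   at::Tensor v = at::randn({2, 256, 8, 64}, opts);
//   std::optional<at::Tensor> out_;           // empty: let the kernel allocate
//   std::optional<at::Tensor> alibi_slopes_;  // empty: no ALiBi bias
//   auto result = pytorch_flash::mha_fwd(
//       q, k, v, out_, alibi_slopes_,
//       /*p_dropout=*/0.0f,
//       /*softmax_scale=*/0.125f,  // 1/sqrt(head_size) for head_size = 64
//       /*is_causal=*/true,
//       /*window_size_left=*/-1,
//       /*window_size_right=*/-1,
//       /*return_softmax=*/false,
//       /*gen_=*/std::nullopt);
//   at::Tensor out = std::get<0>(result);  // first element is the attention
//                                          // output; the rest of the tuple
//                                          // carries logsumexp/RNG state.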

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
               std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
               std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,
               const float softmax_scale,
               const bool zero_tensors,
               bool is_causal,
               int window_size_left,
               int window_size_right,
               const bool return_softmax,
               std::optional<at::Generator> gen_);
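
// Example (a sketch for two packed sequences of lengths 3 and 5 sharing the
// same query/key lengths; the int32 dtype for the cumulative-length tensor
// and all shapes are illustrative assumptions):
//
//   auto opts = at::dtype(at::kHalf).device(at::kCUDA);
//   at::Tensor q = at::randn({8, 8, 64}, opts);  // total_q = 3 + 5
//   at::Tensor k = at::randn({8, 8, 64}, opts);  // total_k = total_q here
//   at::Tensor v = at::randn({8, 8, 64}, opts);
//   at::Tensor cu_seqlens = at::tensor({0, 3, 8},
//       at::dtype(at::kInt).device(at::kCUDA));  // b+1 prefix sums
//   std::optional<at::Tensor> out_, seqused_k, block_table_, alibi_slopes_;
//   auto result = pytorch_flash::mha_varlen_fwd(
//       q, k, v, out_, cu_seqlens, cu_seqlens,
//       seqused_k, block_table_, alibi_slopes_,
//       /*max_seqlen_q=*/5, /*max_seqlen_k=*/5,
//       /*p_dropout=*/0.0f, /*softmax_scale=*/0.125f,
//       /*zero_tensors=*/false, /*is_causal=*/true,
//       /*window_size_left=*/-1, /*window_size_right=*/-1,
//       /*return_softmax=*/false, /*gen_=*/std::nullopt);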

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads x head_size_og
        const at::Tensor &q,   // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,   // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,   // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out,   // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &softmax_lse,     // b x h x seqlen_q
        std::optional<at::Tensor> &dq_,   // batch_size x seqlen_q x num_heads x head_size
        std::optional<at::Tensor> &dk_,   // batch_size x seqlen_k x num_heads_k x head_size
        std::optional<at::Tensor> &dv_,   // batch_size x seqlen_k x num_heads_k x head_size
        std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,         // dropout probability
        const float softmax_scale,
        const bool is_causal,
        int window_size_left,
        int window_size_right,
        const bool deterministic,
        const at::Tensor philox_seed,
        const at::Tensor philox_offset);
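
// Example (a sketch; assumes q, k, v, out, softmax_lse, and the philox
// seed/offset tensors were produced by a matching mha_fwd call with the
// same dropout, scaling, causality, and window settings):
//
//   at::Tensor dout = at::ones_like(out);      // upstream gradient
//   std::optional<at::Tensor> dq_, dk_, dv_;   // empty: kernel allocates grads
//   std::optional<at::Tensor> alibi_slopes_;
//   auto grads = pytorch_flash::mha_bwd(
//       dout, q, k, v, out, softmax_lse,
//       dq_, dk_, dv_, alibi_slopes_,
//       /*p_dropout=*/0.0f,
//       /*softmax_scale=*/0.125f,
//       /*is_causal=*/true,
//       /*window_size_left=*/-1, /*window_size_right=*/-1,
//       /*deterministic=*/false,
//       philox_seed, philox_offset);
//   at::Tensor dq = std::get<0>(grads);  // dk, dv follow in the tuple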

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads x head_size
               const at::Tensor &q,   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &out,   // total_q x num_heads x head_size
               const at::Tensor &softmax_lse,     // b x h x s   softmax logsumexp
               std::optional<at::Tensor> &dq_,   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               std::optional<at::Tensor> &dk_,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               std::optional<at::Tensor> &dv_,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               const int max_seqlen_q,
               const int max_seqlen_k,          // max sequence length to choose the kernel
               const float p_dropout,         // dropout probability
               const float softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
               int window_size_left,
               int window_size_right,
               const bool deterministic,
               const at::Tensor philox_seed,
               const at::Tensor philox_offset);
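
// Example (a sketch mirroring the mha_varlen_fwd example above; all inputs
// besides dout are assumed to come from that forward call and its returned
// logsumexp/RNG state):
//
//   at::Tensor dout = at::ones_like(out);
//   std::optional<at::Tensor> dq_, dk_, dv_, alibi_slopes_;
//   auto grads = pytorch_flash::mha_varlen_bwd(
//       dout, q, k, v, out, softmax_lse,
//       dq_, dk_, dv_, cu_seqlens, cu_seqlens, alibi_slopes_,
//       /*max_seqlen_q=*/5, /*max_seqlen_k=*/5,
//       /*p_dropout=*/0.0f, /*softmax_scale=*/0.125f,
//       /*zero_tensors=*/false, /*is_causal=*/true,
//       /*window_size_left=*/-1, /*window_size_right=*/-1,
//       /*deterministic=*/false,
//       philox_seed, philox_offset);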

} // namespace pytorch_flash