xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cudnn/MHA.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/Tensor.h>
3 
4 namespace at {
5 namespace native {
6 
// Declaration only; the definition is not part of this header.
// Judging by the name and this file's location (native/cudnn/MHA.h), this is
// presumably the forward pass ("fprop") of cuDNN-backed scaled dot-product
// attention -- confirm against the definition.
//
// Parameters as declared:
//   b, h, s_q, s_kv, d_k, d_v  - int64_t size arguments; presumably batch,
//                                heads, query / key-value sequence lengths and
//                                key / value head dims -- TODO confirm.
//   scaling_factor             - float scalar.
//   isTraining, is_causal      - boolean mode flags.
//   dropout_probability        - double here.
//                                NOTE(review): run_cudnn_SDP_bprop declares the
//                                same-named parameter as float.
//   q, k, v                    - input tensors, taken by const reference.
//   attn_bias                  - optional tensor, taken by const reference.
//   softmaxstats, o,
//   dropoutseed, dropoutoffset - tensors taken by non-const reference;
//                                presumably written by the call -- confirm.
7 void run_cudnn_SDP_fprop(
8     int64_t b,
9     int64_t h,
10     int64_t s_q,
11     int64_t s_kv,
12     int64_t d_k,
13     int64_t d_v,
14     float scaling_factor,
15     bool isTraining,
16     bool is_causal,
17     double dropout_probability,
18     const Tensor& q,
19     const Tensor& k,
20     const Tensor& v,
21     const std::optional<Tensor>& attn_bias,
22     Tensor& softmaxstats,
23     Tensor& o,
24     Tensor& dropoutseed,
25     Tensor& dropoutoffset);
26 
// Declaration only; the definition is not part of this header.
// Judging by the name and this file's location (native/cudnn/MHA.h), this is
// presumably the backward pass ("bprop") of cuDNN-backed scaled dot-product
// attention -- confirm against the definition.
//
// Parameters as declared:
//   b, h, s_q, s_kv, d_k, d_v  - int64_t size arguments; presumably batch,
//                                heads, query / key-value sequence lengths and
//                                key / value head dims -- TODO confirm.
//   scaling_factor             - float scalar.
//   is_causal                  - boolean mode flag (no isTraining flag here,
//                                unlike run_cudnn_SDP_fprop).
//   dropout_probability        - float here.
//                                NOTE(review): run_cudnn_SDP_fprop declares the
//                                same-named parameter as double.
//   q, k, v, o, dO,
//   softmaxstats               - tensors taken by const reference.
//   attn_bias                  - optional tensor, taken by const reference.
//   dQ, dK, dV                 - tensors taken by non-const reference;
//                                presumably the gradients written by the call
//                                -- confirm.
//   dropoutseed, dropoutoffset - taken by const reference here, whereas
//                                run_cudnn_SDP_fprop takes them by non-const
//                                reference.
27 void run_cudnn_SDP_bprop(
28     int64_t b,
29     int64_t h,
30     int64_t s_q,
31     int64_t s_kv,
32     int64_t d_k,
33     int64_t d_v,
34     float scaling_factor,
35     bool is_causal,
36     float dropout_probability,
37     const Tensor& q,
38     const Tensor& k,
39     const Tensor& v,
40     const std::optional<Tensor>& attn_bias,
41     const Tensor& o,
42     const Tensor& dO,
43     const Tensor& softmaxstats,
44     Tensor& dQ,
45     Tensor& dK,
46     Tensor& dV,
47     const Tensor& dropoutseed,
48     const Tensor& dropoutoffset);
49 
50 } // namespace native
51 } // namespace at
52