1 #pragma once 2 #include <ATen/core/Tensor.h> 3 4 namespace at { 5 namespace native { 6 7 void run_cudnn_SDP_fprop( 8 int64_t b, 9 int64_t h, 10 int64_t s_q, 11 int64_t s_kv, 12 int64_t d_k, 13 int64_t d_v, 14 float scaling_factor, 15 bool isTraining, 16 bool is_causal, 17 double dropout_probability, 18 const Tensor& q, 19 const Tensor& k, 20 const Tensor& v, 21 const std::optional<Tensor>& attn_bias, 22 Tensor& softmaxstats, 23 Tensor& o, 24 Tensor& dropoutseed, 25 Tensor& dropoutoffset); 26 27 void run_cudnn_SDP_bprop( 28 int64_t b, 29 int64_t h, 30 int64_t s_q, 31 int64_t s_kv, 32 int64_t d_k, 33 int64_t d_v, 34 float scaling_factor, 35 bool is_causal, 36 float dropout_probability, 37 const Tensor& q, 38 const Tensor& k, 39 const Tensor& v, 40 const std::optional<Tensor>& attn_bias, 41 const Tensor& o, 42 const Tensor& dO, 43 const Tensor& softmaxstats, 44 Tensor& dQ, 45 Tensor& dK, 46 Tensor& dV, 47 const Tensor& dropoutseed, 48 const Tensor& dropoutoffset); 49 50 } // namespace native 51 } // namespace at 52