xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cudnn/MHA.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/ATen.h>
2 #include <ATen/Config.h>
3 #include <ATen/cuda/CUDAConfig.h>
4 
5 #if defined(USE_ROCM) || !AT_CUDNN_ENABLED() || \
6     (defined(CUDNN_VERSION) && CUDNN_VERSION < 8900)
7 
8 namespace at {
9 namespace native {
10 
run_cudnn_SDP_fprop(int64_t b,int64_t h,int64_t s_q,int64_t s_kv,int64_t d_qk,int64_t d_v,float scaling_factor,bool isTraining,bool is_causal,double dropout_probability,const Tensor & q,const Tensor & k,const Tensor & v,Tensor & softmaxstats,Tensor & o,Tensor & dropoutseed,Tensor & dropoutoffset)11 void run_cudnn_SDP_fprop(
12     int64_t b,
13     int64_t h,
14     int64_t s_q,
15     int64_t s_kv,
16     int64_t d_qk,
17     int64_t d_v,
18     float scaling_factor,
19     bool isTraining,
20     bool is_causal,
21     double dropout_probability,
22     const Tensor& q,
23     const Tensor& k,
24     const Tensor& v,
25     Tensor& softmaxstats,
26     Tensor& o,
27     Tensor& dropoutseed,
28     Tensor& dropoutoffset) {
29   TORCH_CHECK(
30       false, "PyTorch was not compiled with cuDNN Flash Attention enabled!");
31 }
32 
run_cudnn_SDP_bprop(int64_t b,int64_t h,int64_t s_q,int64_t s_kv,int64_t d_qk,int64_t d_v,float scaling_factor,bool is_causal,float dropout_probability,const Tensor & q,const Tensor & k,const Tensor & v,const Tensor & o,const Tensor & dO,const Tensor & softmaxstats,Tensor & dQ,Tensor & dK,Tensor & dV,const Tensor & dropoutseed,const Tensor & dropoutoffset)33 void run_cudnn_SDP_bprop(
34     int64_t b,
35     int64_t h,
36     int64_t s_q,
37     int64_t s_kv,
38     int64_t d_qk,
39     int64_t d_v,
40     float scaling_factor,
41     bool is_causal,
42     float dropout_probability,
43     const Tensor& q,
44     const Tensor& k,
45     const Tensor& v,
46     const Tensor& o,
47     const Tensor& dO,
48     const Tensor& softmaxstats,
49     Tensor& dQ,
50     Tensor& dK,
51     Tensor& dV,
52     const Tensor& dropoutseed,
53     const Tensor& dropoutoffset) {
54   TORCH_CHECK(
55       false, "PyTorch was not compiled with cuDNN Flash Attention enabled!");
56 }
57 
58 } // namespace native
59 } // namespace at
60 
61 #else // AT_CUDNN_ENABLED && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8900
62 #include <ATen/cudnn/Descriptors.h>
63 #include <ATen/cudnn/Types.h>
64 #include <ATen/cudnn/Utils.h>
65 #include <ATen/native/cudnn/MHA.h>
66 
67 #include <ATen/cuda/Exceptions.h>
68 #include <cudnn_frontend.h>
69 
70 #include <ATen/TensorUtils.h>
71 #include <ATen/native/utils/ParamsHash.h>
72 
73 #include <c10/cuda/CUDACachingAllocator.h>
74 #include <cudnn.h>
75 
76 #include <iostream>
77 
78 namespace at {
79 namespace native {
80 
81 #include <cudnn_frontend.h>
82 
83 namespace fe = cudnn_frontend;
// Bundle stored in the forward graph cache: the built cuDNN frontend graph
// followed by the tensor-attribute handles that run_cudnn_SDP_fprop uses as
// keys of the variant pack when executing the graph. The element order must
// match the std::make_tuple call at the end of build_graph_and_tensors.
using graph_and_tensors = std::tuple<
    std::shared_ptr<fe::graph::Graph>,
    std::shared_ptr<fe::graph::Tensor_attributes>, // Q
    std::shared_ptr<fe::graph::Tensor_attributes>, // K
    std::shared_ptr<fe::graph::Tensor_attributes>, // V
    std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale
    // TODO(eqy): additional options
    // std::shared_ptr<fe::graph::Tensor_attributes>, // Bias,
    // std::shared_ptr<fe::graph::Tensor_attributes>, // SEQ_LEN_Q,
    // std::shared_ptr<fe::graph::Tensor_attributes>, // SEQ_LEN_KV,
    std::shared_ptr<fe::graph::Tensor_attributes>, // Seed
    std::shared_ptr<fe::graph::Tensor_attributes>, // Offset
    // std::shared_ptr<fe::graph::Tensor_attributes>, // Dropout_mask,
    // std::shared_ptr<fe::graph::Tensor_attributes>, // Dropout_scale
    std::shared_ptr<fe::graph::Tensor_attributes>, // O
    std::shared_ptr<fe::graph::Tensor_attributes> // Stats
    >;
101 
// Bundle stored in the backward graph cache: the built cuDNN frontend graph
// followed by the tensor-attribute handles that run_cudnn_SDP_bprop uses as
// keys of the variant pack. The element order must match the std::make_tuple
// call at the end of build_graph_and_tensors_backward.
using graph_and_tensors_backward = std::tuple<
    std::shared_ptr<fe::graph::Graph>,
    std::shared_ptr<fe::graph::Tensor_attributes>, // Q
    std::shared_ptr<fe::graph::Tensor_attributes>, // K
    std::shared_ptr<fe::graph::Tensor_attributes>, // V
    std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale
    std::shared_ptr<fe::graph::Tensor_attributes>, // Seed
    std::shared_ptr<fe::graph::Tensor_attributes>, // Offset
    std::shared_ptr<fe::graph::Tensor_attributes>, // O
    std::shared_ptr<fe::graph::Tensor_attributes>, // dO
    std::shared_ptr<fe::graph::Tensor_attributes>, // Stats
    std::shared_ptr<fe::graph::Tensor_attributes>, // dQ
    std::shared_ptr<fe::graph::Tensor_attributes>, // dK
    std::shared_ptr<fe::graph::Tensor_attributes> // dV
    >;
117 
118 #define MAX_MHA_DIM 4
119 
120 struct MHAParams {
121   c10::DeviceIndex device_id;
122   fe::DataType_t dataType;
123   std::array<int, MAX_MHA_DIM> q_dim;
124   std::array<int, MAX_MHA_DIM> k_dim;
125   std::array<int, MAX_MHA_DIM> v_dim;
126   std::array<int, MAX_MHA_DIM> q_stride;
127   std::array<int, MAX_MHA_DIM> k_stride;
128   std::array<int, MAX_MHA_DIM> v_stride;
129   int64_t b;
130   int64_t h;
131   int64_t s_q;
132   int64_t s_kv;
133   int64_t d_qk;
134   int64_t d_v;
135   double dropout_probability;
136   bool is_causal;
137   bool return_softmaxstats;
138 };
139 
setMHAParams(MHAParams & params,int64_t b,int64_t h,int64_t s_q,int64_t s_kv,int64_t d_qk,int64_t d_v,const Tensor & q,const Tensor & k,const Tensor & v,double dropout_probability,bool is_causal,bool return_softmaxstats)140 void setMHAParams(
141     MHAParams& params,
142     int64_t b,
143     int64_t h,
144     int64_t s_q,
145     int64_t s_kv,
146     int64_t d_qk,
147     int64_t d_v,
148     const Tensor& q,
149     const Tensor& k,
150     const Tensor& v,
151     double dropout_probability,
152     bool is_causal,
153     bool return_softmaxstats) {
154   memset(&params, 0, sizeof(MHAParams));
155   params.device_id = at::cuda::current_device();
156   params.dataType = fe::DataType_t::HALF;
157   if (q.scalar_type() == kBFloat16) {
158     params.dataType = fe::DataType_t::BFLOAT16;
159   }
160   params.b = b;
161   params.h = h;
162   params.d_qk = d_qk;
163   params.d_v = d_v;
164   params.s_q = s_q;
165   params.s_kv = s_kv;
166   params.dropout_probability = dropout_probability;
167   params.is_causal = is_causal;
168   params.return_softmaxstats = return_softmaxstats;
169   TORCH_INTERNAL_ASSERT(
170       q.sizes().size() == MAX_MHA_DIM,
171       "Q tensor has unexpected number of dims, please report a bug to PyTorch.");
172   TORCH_INTERNAL_ASSERT(
173       q.strides().size() == MAX_MHA_DIM,
174       "Q tensor has unexpected number of dims, please report a bug to PyTorch.");
175   TORCH_INTERNAL_ASSERT(
176       k.sizes().size() == MAX_MHA_DIM,
177       "K tensor has unexpected number of dims, please report a bug to PyTorch.");
178   TORCH_INTERNAL_ASSERT(
179       k.strides().size() == MAX_MHA_DIM,
180       "K tensor has unexpected number of dims, please report a bug to PyTorch.");
181   TORCH_INTERNAL_ASSERT(
182       v.sizes().size() == MAX_MHA_DIM,
183       "V tensor has unexpected number of dims, please report a bug to PyTorch.");
184   TORCH_INTERNAL_ASSERT(
185       v.strides().size() == MAX_MHA_DIM,
186       "V tensor has unexpected number of dims, please report a bug to PyTorch.");
187   std::copy(q.sizes().begin(), q.sizes().end(), params.q_dim.begin());
188   std::copy(q.strides().begin(), q.strides().end(), params.q_stride.begin());
189   std::copy(k.sizes().begin(), k.sizes().end(), params.k_dim.begin());
190   std::copy(k.strides().begin(), k.strides().end(), params.k_stride.begin());
191   std::copy(v.sizes().begin(), v.sizes().end(), params.v_dim.begin());
192   std::copy(v.strides().begin(), v.strides().end(), params.v_stride.begin());
193 }
194 
195 struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
MHACacheKeyWrapperat::native::MHACacheKeyWrapper196   MHACacheKeyWrapper(
197       int64_t b,
198       int64_t h,
199       int64_t s_q,
200       int64_t s_kv,
201       int64_t d_qk,
202       int64_t d_v,
203       const Tensor& q,
204       const Tensor& k,
205       const Tensor& v,
206       double dropout_probability,
207       bool is_causal,
208       bool return_softmaxstats) {
209     setMHAParams(
210         this->pod,
211         b,
212         h,
213         s_q,
214         s_kv,
215         d_qk,
216         d_v,
217         q,
218         k,
219         v,
220         dropout_probability,
221         is_causal,
222         return_softmaxstats);
223   }
224 };
225 
226 template <typename T, typename KeyType>
227 struct MHAGraphCache {
228   std::unordered_map<KeyType, T, ParamsWrapperHash<KeyType>> engine_cache;
229 
230   // no mutexes here as caches are now thread local for v8, can also return a
231   // pointer to the Execution Plan if we know it will not be invalidated by
232   // another thread
findat::native::MHAGraphCache233   T* find(const KeyType& key) {
234     auto it = engine_cache.find(key);
235     if (it == engine_cache.end()) {
236       return nullptr;
237     }
238     return &(it->second);
239   }
240 
updateat::native::MHAGraphCache241   void update(const KeyType& key, T& results) {
242     engine_cache.erase(key);
243     engine_cache.emplace(key, std::move(results));
244   }
245 };
246 
// @eqy: use thread-local caches, as cuDNN Execution Plans are not guaranteed
// to be thread safe across all engines; see "Limitations" in
// https://docs.nvidia.com/deeplearning/cudnn/release-notes/index.html
//
// One cache per direction: `mhagraphcache` holds forward graphs and
// `mhagraphbackwardcache` holds backward graphs. Both are keyed by
// MHACacheKeyWrapper.
thread_local MHAGraphCache<graph_and_tensors, MHACacheKeyWrapper> mhagraphcache;
thread_local MHAGraphCache<graph_and_tensors_backward, MHACacheKeyWrapper>
    mhagraphbackwardcache;
253 
build_graph_and_tensors(int64_t b,int64_t h,int64_t s_q,int64_t s_kv,int64_t d_qk,int64_t d_v,float scaling_factor,bool return_softmaxstats,bool is_causal,double dropout_probability,const Tensor & q,const Tensor & k,const Tensor & v,Tensor & softmaxstats,Tensor & o,Tensor & dropoutseed,Tensor & dropoutoffset,cudnnHandle_t & handle,MHAParams & params)254 auto build_graph_and_tensors(
255     int64_t b,
256     int64_t h,
257     int64_t s_q,
258     int64_t s_kv,
259     int64_t d_qk,
260     int64_t d_v,
261     float scaling_factor,
262     bool return_softmaxstats,
263     bool is_causal,
264     double dropout_probability,
265     const Tensor& q,
266     const Tensor& k,
267     const Tensor& v,
268     Tensor& softmaxstats,
269     Tensor& o,
270     Tensor& dropoutseed,
271     Tensor& dropoutoffset,
272     cudnnHandle_t& handle,
273     MHAParams& params) {
274   auto dtype = fe::DataType_t::HALF;
275   if (q.scalar_type() == kBFloat16) {
276     dtype = fe::DataType_t::BFLOAT16;
277   }
278   auto mha_graph = std::make_shared<fe::graph::Graph>();
279   // We're baking in float accumulation and scale types
280   // in theory the graph may support other types, but they
281   // have not been tested
282   mha_graph->set_io_data_type(dtype)
283       .set_intermediate_data_type(fe::DataType_t::FLOAT)
284       .set_compute_data_type(fe::DataType_t::FLOAT);
285   auto Q = mha_graph->tensor(
286       fe::graph::Tensor_attributes()
287           .set_name("Q")
288           .set_dim(
289               std::vector<int64_t>(params.q_dim.begin(), params.q_dim.end()))
290           .set_stride(std::vector<int64_t>(
291               params.q_stride.begin(), params.q_stride.end())));
292   auto K = mha_graph->tensor(
293       fe::graph::Tensor_attributes()
294           .set_name("K")
295           .set_dim(
296               std::vector<int64_t>(params.k_dim.begin(), params.k_dim.end()))
297           .set_stride(std::vector<int64_t>(
298               params.k_stride.begin(), params.k_stride.end())));
299   auto V = mha_graph->tensor(
300       fe::graph::Tensor_attributes()
301           .set_name("V")
302           .set_dim(
303               std::vector<int64_t>(params.v_dim.begin(), params.v_dim.end()))
304           .set_stride(std::vector<int64_t>(
305               params.v_stride.begin(), params.v_stride.end())));
306   auto attn_scale =
307       mha_graph->tensor(fe::graph::Tensor_attributes()
308                             .set_name("Attn_scale")
309                             .set_dim({1, 1, 1, 1})
310                             .set_stride({1, 1, 1, 1})
311                             .set_is_pass_by_value(true)
312                             .set_data_type(fe::DataType_t::FLOAT));
313   // TODO(eqy): support bias in the future in a follow-up PR
314   // auto bias = mha_graph->tensor(fe::graph::Tensor_attributes()
315   //                         .set_name("bias")
316   //                         .set_dim({b, 1, s_q, s_kv})
317   //                         .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1}));
318   auto seed = mha_graph->tensor(fe::graph::Tensor_attributes()
319                                     .set_name("Seed")
320                                     .set_dim({1, 1, 1, 1})
321                                     .set_stride({1, 1, 1, 1})
322                                     .set_data_type(fe::DataType_t::INT32));
323   auto offset = mha_graph->tensor(fe::graph::Tensor_attributes()
324                                       .set_name("Offset")
325                                       .set_dim({1, 1, 1, 1})
326                                       .set_stride({1, 1, 1, 1})
327                                       .set_data_type(fe::DataType_t::INT32));
328   auto scaled_dot_product_flash_attention_options =
329       fe::graph::SDPA_attributes()
330           .set_name("CUDNN_SDPA")
331           .set_is_inference(return_softmaxstats == false)
332           .set_causal_mask(is_causal)
333           .set_attn_scale(attn_scale)
334           .set_dropout(dropout_probability, seed, offset);
335   // Optional bias in flash attention is only supported 8.9.3 onwards
336   if (cudnnGetVersion() >= 8904) {
337     // scaled_dot_product_flash_attention_options.set_alibi_mask(true);
338   }
339 
340   auto seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
341                                      .set_name("Seq_q")
342                                      .set_dim({b, 1, 1, 1})
343                                      .set_stride({1, 1, 1, 1})
344                                      .set_data_type(fe::DataType_t::INT32));
345   auto seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
346                                       .set_name("Seq_kv")
347                                       .set_dim({b, 1, 1, 1})
348                                       .set_stride({1, 1, 1, 1})
349                                       .set_data_type(fe::DataType_t::INT32));
350 
351   // if (cudnnGetVersion() >= 8903) {
352   //     scaled_dot_product_flash_attention_options.set_bias(bias)
353   //         .set_padding_mask(true)
354   //         .set_seq_len_q(seq_q)
355   //         .set_seq_len_kv(seq_kv);
356   // }
357 
358   auto [O, Stats] =
359       mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options);
360   O->set_output(true)
361       .set_dim(std::vector<int64_t>(
362           o.sizes().data(), o.sizes().data() + o.sizes().size()))
363       .set_stride(std::vector<int64_t>(
364           o.strides().data(), o.strides().data() + o.strides().size()));
365 
366   if (Stats) {
367     Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT);
368   }
369 
370   AT_CUDNN_FRONTEND_CHECK(mha_graph->validate());
371   AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle));
372   AT_CUDNN_FRONTEND_CHECK(
373       mha_graph->create_execution_plans({fe::HeurMode_t::A}));
374   AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle));
375   AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle));
376 
377   return std::make_tuple(
378       std::move(mha_graph),
379       std::move(Q),
380       std::move(K),
381       std::move(V),
382       std::move(attn_scale),
383       std::move(seed),
384       std::move(offset),
385       std::move(O),
386       std::move(Stats));
387 }
388 
build_graph_and_tensors_backward(int64_t b,int64_t h,int64_t s_q,int64_t s_kv,int64_t d_qk,int64_t d_v,float scaling_factor,bool is_causal,float dropout_probability,const Tensor & q,const Tensor & k,const Tensor & v,const Tensor & o,const Tensor & dO,const Tensor & softmaxstats,Tensor & dQ,Tensor & dK,Tensor & dV,const Tensor & dropoutseed,const Tensor & dropoutoffset,cudnnHandle_t & handle,MHAParams & params)389 auto build_graph_and_tensors_backward(
390     int64_t b,
391     int64_t h,
392     int64_t s_q,
393     int64_t s_kv,
394     int64_t d_qk,
395     int64_t d_v,
396     float scaling_factor,
397     bool is_causal,
398     float dropout_probability,
399     const Tensor& q,
400     const Tensor& k,
401     const Tensor& v,
402     const Tensor& o,
403     const Tensor& dO,
404     const Tensor& softmaxstats,
405     Tensor& dQ,
406     Tensor& dK,
407     Tensor& dV,
408     const Tensor& dropoutseed,
409     const Tensor& dropoutoffset,
410     cudnnHandle_t& handle,
411     MHAParams& params) {
412   auto dtype = fe::DataType_t::HALF;
413   if (q.scalar_type() == kBFloat16) {
414     dtype = fe::DataType_t::BFLOAT16;
415   }
416   auto mha_graph = std::make_shared<fe::graph::Graph>();
417   // We're baking in float accumulation and scale types
418   // in theory the graph may support other types, but they
419   // have not been tested
420   mha_graph->set_io_data_type(dtype)
421       .set_intermediate_data_type(fe::DataType_t::FLOAT)
422       .set_compute_data_type(fe::DataType_t::FLOAT);
423   auto Q = mha_graph->tensor(
424       fe::graph::Tensor_attributes()
425           .set_name("Q")
426           .set_dim(std::vector<int64_t>(q.sizes().begin(), q.sizes().end()))
427           .set_stride(
428               std::vector<int64_t>(q.strides().begin(), q.strides().end())));
429   auto K = mha_graph->tensor(
430       fe::graph::Tensor_attributes()
431           .set_name("K")
432           .set_dim(std::vector<int64_t>(k.sizes().begin(), k.sizes().end()))
433           .set_stride(
434               std::vector<int64_t>(k.strides().begin(), k.strides().end())));
435   auto V = mha_graph->tensor(
436       fe::graph::Tensor_attributes()
437           .set_name("V")
438           .set_dim(std::vector<int64_t>(v.sizes().begin(), v.sizes().end()))
439           .set_stride(
440               std::vector<int64_t>(v.strides().begin(), v.strides().end())));
441   auto attn_scale =
442       mha_graph->tensor(fe::graph::Tensor_attributes()
443                             .set_name("Attn_scale")
444                             .set_dim({1, 1, 1, 1})
445                             .set_stride({1, 1, 1, 1})
446                             .set_is_pass_by_value(true)
447                             .set_data_type(fe::DataType_t::FLOAT));
448   auto Seed = mha_graph->tensor(fe::graph::Tensor_attributes()
449                                     .set_name("Seed")
450                                     .set_dim({1, 1, 1, 1})
451                                     .set_stride({1, 1, 1, 1})
452                                     .set_data_type(fe::DataType_t::INT32));
453   auto Offset = mha_graph->tensor(fe::graph::Tensor_attributes()
454                                       .set_name("Offset")
455                                       .set_dim({1, 1, 1, 1})
456                                       .set_stride({1, 1, 1, 1})
457                                       .set_data_type(fe::DataType_t::INT32));
458   auto O = mha_graph->tensor(
459       fe::graph::Tensor_attributes()
460           .set_name("O")
461           .set_dim(std::vector<int64_t>(o.sizes().begin(), o.sizes().end()))
462           .set_stride(
463               std::vector<int64_t>(o.strides().begin(), o.strides().end())));
464   auto STATS = mha_graph->tensor(
465       fe::graph::Tensor_attributes()
466           .set_name("Stats")
467           .set_dim(std::vector<int64_t>(
468               softmaxstats.sizes().begin(), softmaxstats.sizes().end()))
469           .set_stride(std::vector<int64_t>(
470               softmaxstats.strides().begin(), softmaxstats.strides().end()))
471           .set_data_type(fe::DataType_t::FLOAT));
472   auto DO = mha_graph->tensor(
473       fe::graph::Tensor_attributes()
474           .set_name("DO")
475           .set_dim(std::vector<int64_t>(dO.sizes().begin(), dO.sizes().end()))
476           .set_stride(
477               std::vector<int64_t>(dO.strides().begin(), dO.strides().end())));
478   auto sdpa_backward_options = fe::graph::SDPA_backward_attributes()
479                                    .set_name("CUDNN_SDPA_BACKWARD")
480                                    .set_causal_mask(is_causal)
481                                    .set_attn_scale(attn_scale);
482   if (dropout_probability != 0.0f) {
483     sdpa_backward_options.set_dropout(dropout_probability, Seed, Offset);
484   }
485   auto [DQ, DK, DV] =
486       mha_graph->sdpa_backward(Q, K, V, O, DO, STATS, sdpa_backward_options);
487   DQ->set_output(true)
488       .set_dim(std::vector<int64_t>(dQ.sizes().begin(), dQ.sizes().end()))
489       .set_stride(
490           std::vector<int64_t>(dQ.strides().begin(), dQ.strides().end()));
491   DK->set_output(true)
492       .set_dim(std::vector<int64_t>(dK.sizes().begin(), dK.sizes().end()))
493       .set_stride(
494           std::vector<int64_t>(dK.strides().begin(), dK.strides().end()));
495   DV->set_output(true)
496       .set_dim(std::vector<int64_t>(dV.sizes().begin(), dV.sizes().end()))
497       .set_stride(
498           std::vector<int64_t>(dV.strides().begin(), dV.strides().end()));
499   AT_CUDNN_FRONTEND_CHECK(mha_graph->validate());
500   AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle));
501   AT_CUDNN_FRONTEND_CHECK(
502       mha_graph->create_execution_plans({fe::HeurMode_t::A}));
503   AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle));
504   AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle));
505   return std::make_tuple(
506       std::move(mha_graph),
507       std::move(Q),
508       std::move(K),
509       std::move(V),
510       std::move(attn_scale),
511       std::move(Seed),
512       std::move(Offset),
513       std::move(O),
514       std::move(DO),
515       std::move(STATS),
516       std::move(DQ),
517       std::move(DK),
518       std::move(DV));
519 }
520 
run_cudnn_SDP_fprop(int64_t b,int64_t h,int64_t s_q,int64_t s_kv,int64_t d_qk,int64_t d_v,float scaling_factor,bool return_softmaxstats,bool is_causal,double dropout_probability,const Tensor & q,const Tensor & k,const Tensor & v,Tensor & softmaxstats,Tensor & o,Tensor & dropoutseed,Tensor & dropoutoffset)521 void run_cudnn_SDP_fprop(
522     int64_t b,
523     int64_t h,
524     int64_t s_q,
525     int64_t s_kv,
526     int64_t d_qk,
527     int64_t d_v,
528     float scaling_factor,
529     bool return_softmaxstats,
530     bool is_causal,
531     double dropout_probability,
532     const Tensor& q,
533     const Tensor& k,
534     const Tensor& v,
535     Tensor& softmaxstats,
536     Tensor& o,
537     Tensor& dropoutseed,
538     Tensor& dropoutoffset) {
539   cudnnHandle_t handle = getCudnnHandle();
540   o = at::empty_strided(
541       {b, h, s_q, d_v}, {s_q * h * d_v, d_v, h * d_v, 1}, q.options());
542   if (return_softmaxstats) {
543     // TODO(eqy): verify that this is correct
544     softmaxstats = at::empty({b, h, s_q}, q.options().dtype(kFloat));
545   }
546 
547   auto key = MHACacheKeyWrapper(
548       b,
549       h,
550       s_q,
551       s_kv,
552       d_qk,
553       d_v,
554       q,
555       k,
556       v,
557       dropout_probability,
558       is_causal,
559       return_softmaxstats);
560   auto graph_and_tensors_ptr = mhagraphcache.find(key);
561   graph_and_tensors graph_and_tensors_values;
562   if (graph_and_tensors_ptr) {
563     graph_and_tensors_values = *graph_and_tensors_ptr;
564   } else {
565     graph_and_tensors_values = build_graph_and_tensors(
566         b,
567         h,
568         s_q,
569         s_kv,
570         d_qk,
571         d_v,
572         scaling_factor,
573         return_softmaxstats,
574         is_causal,
575         dropout_probability,
576         q,
577         k,
578         v,
579         softmaxstats,
580         o,
581         dropoutseed,
582         dropoutoffset,
583         handle,
584         key.pod);
585   }
586   auto [mha_graph, Q, K, V, attn_scale, seed, offset, O, Stats] =
587       graph_and_tensors_values;
588   std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*>
589       variant_pack = {
590           {Q, q.data_ptr()},
591           {K, k.data_ptr()},
592           {V, v.data_ptr()},
593           {attn_scale, &scaling_factor},
594           //{bias, bias.data_ptr()},
595           {seed, dropoutseed.data_ptr()},
596           {offset, dropoutoffset.data_ptr()},
597           {O, o.data_ptr()}};
598   if (return_softmaxstats) {
599     variant_pack[Stats] = softmaxstats.data_ptr();
600   }
601   auto workspace_size = mha_graph->get_workspace_size();
602   auto workspace_ptr =
603       c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
604   TORCH_CHECK(
605       mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
606   mhagraphcache.update(key, graph_and_tensors_values);
607 }
608 
run_cudnn_SDP_bprop(int64_t b,int64_t h,int64_t s_q,int64_t s_kv,int64_t d_qk,int64_t d_v,float scaling_factor,bool is_causal,float dropout_probability,const Tensor & q,const Tensor & k,const Tensor & v,const Tensor & o,const Tensor & dO,const Tensor & softmaxstats,Tensor & dQ,Tensor & dK,Tensor & dV,const Tensor & dropoutseed,const Tensor & dropoutoffset)609 void run_cudnn_SDP_bprop(
610     int64_t b,
611     int64_t h,
612     int64_t s_q,
613     int64_t s_kv,
614     int64_t d_qk,
615     int64_t d_v,
616     float scaling_factor,
617     bool is_causal,
618     float dropout_probability,
619     const Tensor& q,
620     const Tensor& k,
621     const Tensor& v,
622     const Tensor& o,
623     const Tensor& dO,
624     const Tensor& softmaxstats,
625     Tensor& dQ,
626     Tensor& dK,
627     Tensor& dV,
628     const Tensor& dropoutseed,
629     const Tensor& dropoutoffset) {
630   cudnnHandle_t handle = getCudnnHandle();
631   auto key = MHACacheKeyWrapper(
632       b,
633       h,
634       s_q,
635       s_kv,
636       d_qk,
637       d_v,
638       q,
639       k,
640       v,
641       dropout_probability,
642       is_causal,
643       true);
644   auto graph_and_tensors_backward_ptr = mhagraphbackwardcache.find(key);
645   graph_and_tensors_backward graph_and_tensors_backward_values;
646   if (graph_and_tensors_backward_ptr) {
647     graph_and_tensors_backward_values = *graph_and_tensors_backward_ptr;
648   } else {
649     graph_and_tensors_backward_values = build_graph_and_tensors_backward(
650         b,
651         h,
652         s_q,
653         s_kv,
654         d_qk,
655         d_v,
656         scaling_factor,
657         is_causal,
658         dropout_probability,
659         q,
660         k,
661         v,
662         o,
663         dO,
664         softmaxstats,
665         dQ,
666         dK,
667         dV,
668         dropoutseed,
669         dropoutoffset,
670         handle,
671         key.pod);
672   }
673   auto
674       [mha_graph, Q, K, V, attn_scale, Seed, Offset, O, Do, Stats, Dq, Dk, Dv] =
675           graph_and_tensors_backward_values;
676   std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*>
677       variant_pack = {// inputs
678                       {Q, q.data_ptr()},
679                       {K, k.data_ptr()},
680                       {V, v.data_ptr()},
681                       {O, o.data_ptr()},
682                       {Do, dO.data_ptr()},
683                       {Stats, softmaxstats.data_ptr()},
684                       // outputs
685                       {Dq, dQ.data_ptr()},
686                       {Dk, dK.data_ptr()},
687                       {Dv, dV.data_ptr()},
688                       // pass by value
689                       {attn_scale, &scaling_factor}};
690   if (dropout_probability != 0.0f) {
691     variant_pack[Seed] = dropoutseed.data_ptr();
692     variant_pack[Offset] = dropoutoffset.data_ptr();
693   }
694   auto workspace_size = mha_graph->get_workspace_size();
695   auto workspace_ptr =
696       c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
697   TORCH_CHECK(!workspace_size || workspace_ptr.get());
698   TORCH_CHECK(
699       mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
700   mhagraphbackwardcache.update(key, graph_and_tensors_backward_values);
701 }
702 
703 } // namespace native
704 } // namespace at
705 #endif
706