#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/cuda/CUDAConfig.h>

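// When building for ROCm, without cuDNN, or against a cuDNN older than 8.9,
// compile stub implementations of the SDPA (scaled dot-product attention)
// entry points that simply raise an error at call time; the real cuDNN
// frontend implementation follows in the #else branch.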
#if defined(USE_ROCM) || !AT_CUDNN_ENABLED() || \
    (defined(CUDNN_VERSION) && CUDNN_VERSION < 8900)

namespace at {
namespace native {

void run_cudnn_SDP_fprop(
    int64_t b,
    int64_t h,
    int64_t s_q,
    int64_t s_kv,
    int64_t d_qk,
    int64_t d_v,
    float scaling_factor,
    bool isTraining,
    bool is_causal,
    double dropout_probability,
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    Tensor& softmaxstats,
    Tensor& o,
    Tensor& dropoutseed,
    Tensor& dropoutoffset) {
  TORCH_CHECK(
      false, "PyTorch was not compiled with cuDNN Flash Attention enabled!");
}

void run_cudnn_SDP_bprop(
    int64_t b,
    int64_t h,
    int64_t s_q,
    int64_t s_kv,
    int64_t d_qk,
    int64_t d_v,
    float scaling_factor,
    bool is_causal,
    float dropout_probability,
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    const Tensor& o,
    const Tensor& dO,
    const Tensor& softmaxstats,
    Tensor& dQ,
    Tensor& dK,
    Tensor& dV,
    const Tensor& dropoutseed,
    const Tensor& dropoutoffset) {
  TORCH_CHECK(
      false, "PyTorch was not compiled with cuDNN Flash Attention enabled!");
}

} // namespace native
} // namespace at

#else // AT_CUDNN_ENABLED && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8900
#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/native/cudnn/MHA.h>

#include <ATen/cuda/Exceptions.h>
#include <cudnn_frontend.h>

#include <ATen/TensorUtils.h>
#include <ATen/native/utils/ParamsHash.h>

#include <c10/cuda/CUDACachingAllocator.h>
#include <cudnn.h>

#include <iostream>

namespace at {
namespace native {


namespace fe = cudnn_frontend;
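// Bundle of a built forward graph together with the tensor-attribute handles
// that must be bound to device pointers in the variant pack at execution time.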
using graph_and_tensors = std::tuple<
    std::shared_ptr<fe::graph::Graph>,
    std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
    std::shared_ptr<fe::graph::Tensor_attributes>, // K,
    std::shared_ptr<fe::graph::Tensor_attributes>, // V,
    std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
    // TODO(eqy): additional options
    // std::shared_ptr<fe::graph::Tensor_attributes>, // Bias,
    // std::shared_ptr<fe::graph::Tensor_attributes>, // SEQ_LEN_Q,
    // std::shared_ptr<fe::graph::Tensor_attributes>, // SEQ_LEN_KV,
    std::shared_ptr<fe::graph::Tensor_attributes>, // Seed,
    std::shared_ptr<fe::graph::Tensor_attributes>, // Offset,
    // std::shared_ptr<fe::graph::Tensor_attributes>, // Dropout_mask,
    // std::shared_ptr<fe::graph::Tensor_attributes>, // Dropout_scale
    std::shared_ptr<fe::graph::Tensor_attributes>, // O
    std::shared_ptr<fe::graph::Tensor_attributes> // Stats
    >;

using graph_and_tensors_backward = std::tuple<
    std::shared_ptr<fe::graph::Graph>,
    std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
    std::shared_ptr<fe::graph::Tensor_attributes>, // K,
    std::shared_ptr<fe::graph::Tensor_attributes>, // V,
    std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
    std::shared_ptr<fe::graph::Tensor_attributes>, // Seed,
    std::shared_ptr<fe::graph::Tensor_attributes>, // Offset,
    std::shared_ptr<fe::graph::Tensor_attributes>, // O,
    std::shared_ptr<fe::graph::Tensor_attributes>, // dO,
    std::shared_ptr<fe::graph::Tensor_attributes>, // Stats,
    std::shared_ptr<fe::graph::Tensor_attributes>, // dQ,
    std::shared_ptr<fe::graph::Tensor_attributes>, // dK,
    std::shared_ptr<fe::graph::Tensor_attributes> // dV
    >;

#define MAX_MHA_DIM 4

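// Plain-old-data description of an MHA call (shapes, strides, dtype, dropout,
// masking); a byte-wise hash and comparison of this struct is what keys the
// thread-local graph caches below.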
struct MHAParams {
  c10::DeviceIndex device_id;
  fe::DataType_t dataType;
  std::array<int, MAX_MHA_DIM> q_dim;
  std::array<int, MAX_MHA_DIM> k_dim;
  std::array<int, MAX_MHA_DIM> v_dim;
  std::array<int, MAX_MHA_DIM> q_stride;
  std::array<int, MAX_MHA_DIM> k_stride;
  std::array<int, MAX_MHA_DIM> v_stride;
  int64_t b;
  int64_t h;
  int64_t s_q;
  int64_t s_kv;
  int64_t d_qk;
  int64_t d_v;
  double dropout_probability;
  bool is_causal;
  bool return_softmaxstats;
};

void setMHAParams(
    MHAParams& params,
    int64_t b,
    int64_t h,
    int64_t s_q,
    int64_t s_kv,
    int64_t d_qk,
    int64_t d_v,
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    double dropout_probability,
    bool is_causal,
    bool return_softmaxstats) {
  memset(&params, 0, sizeof(MHAParams));
  params.device_id = at::cuda::current_device();
  params.dataType = fe::DataType_t::HALF;
  if (q.scalar_type() == kBFloat16) {
    params.dataType = fe::DataType_t::BFLOAT16;
  }
  params.b = b;
  params.h = h;
  params.d_qk = d_qk;
  params.d_v = d_v;
  params.s_q = s_q;
  params.s_kv = s_kv;
  params.dropout_probability = dropout_probability;
  params.is_causal = is_causal;
  params.return_softmaxstats = return_softmaxstats;
  TORCH_INTERNAL_ASSERT(
      q.sizes().size() == MAX_MHA_DIM,
      "Q tensor has unexpected number of dims, please report a bug to PyTorch.");
  TORCH_INTERNAL_ASSERT(
      q.strides().size() == MAX_MHA_DIM,
      "Q tensor has unexpected number of dims, please report a bug to PyTorch.");
  TORCH_INTERNAL_ASSERT(
      k.sizes().size() == MAX_MHA_DIM,
      "K tensor has unexpected number of dims, please report a bug to PyTorch.");
  TORCH_INTERNAL_ASSERT(
      k.strides().size() == MAX_MHA_DIM,
      "K tensor has unexpected number of dims, please report a bug to PyTorch.");
  TORCH_INTERNAL_ASSERT(
      v.sizes().size() == MAX_MHA_DIM,
      "V tensor has unexpected number of dims, please report a bug to PyTorch.");
  TORCH_INTERNAL_ASSERT(
      v.strides().size() == MAX_MHA_DIM,
      "V tensor has unexpected number of dims, please report a bug to PyTorch.");
  std::copy(q.sizes().begin(), q.sizes().end(), params.q_dim.begin());
  std::copy(q.strides().begin(), q.strides().end(), params.q_stride.begin());
  std::copy(k.sizes().begin(), k.sizes().end(), params.k_dim.begin());
  std::copy(k.strides().begin(), k.strides().end(), params.k_stride.begin());
  std::copy(v.sizes().begin(), v.sizes().end(), params.v_dim.begin());
  std::copy(v.strides().begin(), v.strides().end(), params.v_stride.begin());
}

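// Hashable wrapper around MHAParams; two calls with identical shapes, strides,
// dtype, and flags map to the same cache entry.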
struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
  MHACacheKeyWrapper(
      int64_t b,
      int64_t h,
      int64_t s_q,
      int64_t s_kv,
      int64_t d_qk,
      int64_t d_v,
      const Tensor& q,
      const Tensor& k,
      const Tensor& v,
      double dropout_probability,
      bool is_causal,
      bool return_softmaxstats) {
    setMHAParams(
        this->pod,
        b,
        h,
        s_q,
        s_kv,
        d_qk,
        d_v,
        q,
        k,
        v,
        dropout_probability,
        is_causal,
        return_softmaxstats);
  }
};

template <typename T, typename KeyType>
struct MHAGraphCache {
  std::unordered_map<KeyType, T, ParamsWrapperHash<KeyType>> engine_cache;

  // no mutexes here as caches are now thread local for v8, can also return a
  // pointer to the Execution Plan if we know it will not be invalidated by
  // another thread
  T* find(const KeyType& key) {
    auto it = engine_cache.find(key);
    if (it == engine_cache.end()) {
      return nullptr;
    }
    return &(it->second);
  }

  void update(const KeyType& key, T& results) {
    engine_cache.erase(key);
    engine_cache.emplace(key, std::move(results));
  }
};

// @eqy: use thread local caches as cuDNN Execution Plans are not guaranteed to
// be thread safe across all engines, see Limitations in
// https://docs.nvidia.com/deeplearning/cudnn/release-notes/index.html
thread_local MHAGraphCache<graph_and_tensors, MHACacheKeyWrapper> mhagraphcache;
thread_local MHAGraphCache<graph_and_tensors_backward, MHACacheKeyWrapper>
    mhagraphbackwardcache;

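// Build the cuDNN frontend graph for the SDPA forward pass: declare Q/K/V,
// the pass-by-value attention scale, and the dropout seed/offset tensors,
// then validate the graph and build an execution plan for this device.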
auto build_graph_and_tensors(
    int64_t b,
    int64_t h,
    int64_t s_q,
    int64_t s_kv,
    int64_t d_qk,
    int64_t d_v,
    float scaling_factor,
    bool return_softmaxstats,
    bool is_causal,
    double dropout_probability,
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    Tensor& softmaxstats,
    Tensor& o,
    Tensor& dropoutseed,
    Tensor& dropoutoffset,
    cudnnHandle_t& handle,
    MHAParams& params) {
  auto dtype = fe::DataType_t::HALF;
  if (q.scalar_type() == kBFloat16) {
    dtype = fe::DataType_t::BFLOAT16;
  }
  auto mha_graph = std::make_shared<fe::graph::Graph>();
  // We're baking in float accumulation and scale types;
  // in theory the graph may support other types, but they
  // have not been tested.
  mha_graph->set_io_data_type(dtype)
      .set_intermediate_data_type(fe::DataType_t::FLOAT)
      .set_compute_data_type(fe::DataType_t::FLOAT);
  auto Q = mha_graph->tensor(
      fe::graph::Tensor_attributes()
          .set_name("Q")
          .set_dim(
              std::vector<int64_t>(params.q_dim.begin(), params.q_dim.end()))
          .set_stride(std::vector<int64_t>(
              params.q_stride.begin(), params.q_stride.end())));
  auto K = mha_graph->tensor(
      fe::graph::Tensor_attributes()
          .set_name("K")
          .set_dim(
              std::vector<int64_t>(params.k_dim.begin(), params.k_dim.end()))
          .set_stride(std::vector<int64_t>(
              params.k_stride.begin(), params.k_stride.end())));
  auto V = mha_graph->tensor(
      fe::graph::Tensor_attributes()
          .set_name("V")
          .set_dim(
              std::vector<int64_t>(params.v_dim.begin(), params.v_dim.end()))
          .set_stride(std::vector<int64_t>(
              params.v_stride.begin(), params.v_stride.end())));
  auto attn_scale =
      mha_graph->tensor(fe::graph::Tensor_attributes()
                            .set_name("Attn_scale")
                            .set_dim({1, 1, 1, 1})
                            .set_stride({1, 1, 1, 1})
                            .set_is_pass_by_value(true)
                            .set_data_type(fe::DataType_t::FLOAT));
  // TODO(eqy): support bias in the future in a follow-up PR
  // auto bias = mha_graph->tensor(fe::graph::Tensor_attributes()
  //                                   .set_name("bias")
  //                                   .set_dim({b, 1, s_q, s_kv})
  //                                   .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1}));
  auto seed = mha_graph->tensor(fe::graph::Tensor_attributes()
                                    .set_name("Seed")
                                    .set_dim({1, 1, 1, 1})
                                    .set_stride({1, 1, 1, 1})
                                    .set_data_type(fe::DataType_t::INT32));
  auto offset = mha_graph->tensor(fe::graph::Tensor_attributes()
                                      .set_name("Offset")
                                      .set_dim({1, 1, 1, 1})
                                      .set_stride({1, 1, 1, 1})
                                      .set_data_type(fe::DataType_t::INT32));
  auto scaled_dot_product_flash_attention_options =
      fe::graph::SDPA_attributes()
          .set_name("CUDNN_SDPA")
          .set_is_inference(return_softmaxstats == false)
          .set_causal_mask(is_causal)
          .set_attn_scale(attn_scale)
          .set_dropout(dropout_probability, seed, offset);
  // Optional bias in flash attention is only supported 8.9.3 onwards
  if (cudnnGetVersion() >= 8904) {
    // scaled_dot_product_flash_attention_options.set_alibi_mask(true);
  }

  auto seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
                                     .set_name("Seq_q")
                                     .set_dim({b, 1, 1, 1})
                                     .set_stride({1, 1, 1, 1})
                                     .set_data_type(fe::DataType_t::INT32));
  auto seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
                                      .set_name("Seq_kv")
                                      .set_dim({b, 1, 1, 1})
                                      .set_stride({1, 1, 1, 1})
                                      .set_data_type(fe::DataType_t::INT32));

  // if (cudnnGetVersion() >= 8903) {
  //   scaled_dot_product_flash_attention_options.set_bias(bias)
  //       .set_padding_mask(true)
  //       .set_seq_len_q(seq_q)
  //       .set_seq_len_kv(seq_kv);
  // }

  auto [O, Stats] =
      mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options);
  O->set_output(true)
      .set_dim(std::vector<int64_t>(
          o.sizes().data(), o.sizes().data() + o.sizes().size()))
      .set_stride(std::vector<int64_t>(
          o.strides().data(), o.strides().data() + o.strides().size()));

  if (Stats) {
    Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT);
  }

  AT_CUDNN_FRONTEND_CHECK(mha_graph->validate());
  AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle));
  AT_CUDNN_FRONTEND_CHECK(
      mha_graph->create_execution_plans({fe::HeurMode_t::A}));
  AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle));
  AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle));

  return std::make_tuple(
      std::move(mha_graph),
      std::move(Q),
      std::move(K),
      std::move(V),
      std::move(attn_scale),
      std::move(seed),
      std::move(offset),
      std::move(O),
      std::move(Stats));
}

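// Same construction for the backward pass: the forward output O, its gradient
// dO, and the saved softmax stats come in as additional graph inputs, and
// dQ/dK/dV are registered as graph outputs.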
auto build_graph_and_tensors_backward(
    int64_t b,
    int64_t h,
    int64_t s_q,
    int64_t s_kv,
    int64_t d_qk,
    int64_t d_v,
    float scaling_factor,
    bool is_causal,
    float dropout_probability,
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    const Tensor& o,
    const Tensor& dO,
    const Tensor& softmaxstats,
    Tensor& dQ,
    Tensor& dK,
    Tensor& dV,
    const Tensor& dropoutseed,
    const Tensor& dropoutoffset,
    cudnnHandle_t& handle,
    MHAParams& params) {
  auto dtype = fe::DataType_t::HALF;
  if (q.scalar_type() == kBFloat16) {
    dtype = fe::DataType_t::BFLOAT16;
  }
  auto mha_graph = std::make_shared<fe::graph::Graph>();
  // We're baking in float accumulation and scale types;
  // in theory the graph may support other types, but they
  // have not been tested.
  mha_graph->set_io_data_type(dtype)
      .set_intermediate_data_type(fe::DataType_t::FLOAT)
      .set_compute_data_type(fe::DataType_t::FLOAT);
  auto Q = mha_graph->tensor(
      fe::graph::Tensor_attributes()
          .set_name("Q")
          .set_dim(std::vector<int64_t>(q.sizes().begin(), q.sizes().end()))
          .set_stride(
              std::vector<int64_t>(q.strides().begin(), q.strides().end())));
  auto K = mha_graph->tensor(
      fe::graph::Tensor_attributes()
          .set_name("K")
          .set_dim(std::vector<int64_t>(k.sizes().begin(), k.sizes().end()))
          .set_stride(
              std::vector<int64_t>(k.strides().begin(), k.strides().end())));
  auto V = mha_graph->tensor(
      fe::graph::Tensor_attributes()
          .set_name("V")
          .set_dim(std::vector<int64_t>(v.sizes().begin(), v.sizes().end()))
          .set_stride(
              std::vector<int64_t>(v.strides().begin(), v.strides().end())));
  auto attn_scale =
      mha_graph->tensor(fe::graph::Tensor_attributes()
                            .set_name("Attn_scale")
                            .set_dim({1, 1, 1, 1})
                            .set_stride({1, 1, 1, 1})
                            .set_is_pass_by_value(true)
                            .set_data_type(fe::DataType_t::FLOAT));
  auto Seed = mha_graph->tensor(fe::graph::Tensor_attributes()
                                    .set_name("Seed")
                                    .set_dim({1, 1, 1, 1})
                                    .set_stride({1, 1, 1, 1})
                                    .set_data_type(fe::DataType_t::INT32));
  auto Offset = mha_graph->tensor(fe::graph::Tensor_attributes()
                                      .set_name("Offset")
                                      .set_dim({1, 1, 1, 1})
                                      .set_stride({1, 1, 1, 1})
                                      .set_data_type(fe::DataType_t::INT32));
  auto O = mha_graph->tensor(
      fe::graph::Tensor_attributes()
          .set_name("O")
          .set_dim(std::vector<int64_t>(o.sizes().begin(), o.sizes().end()))
          .set_stride(
              std::vector<int64_t>(o.strides().begin(), o.strides().end())));
  auto STATS = mha_graph->tensor(
      fe::graph::Tensor_attributes()
          .set_name("Stats")
          .set_dim(std::vector<int64_t>(
              softmaxstats.sizes().begin(), softmaxstats.sizes().end()))
          .set_stride(std::vector<int64_t>(
              softmaxstats.strides().begin(), softmaxstats.strides().end()))
          .set_data_type(fe::DataType_t::FLOAT));
  auto DO = mha_graph->tensor(
      fe::graph::Tensor_attributes()
          .set_name("DO")
          .set_dim(std::vector<int64_t>(dO.sizes().begin(), dO.sizes().end()))
          .set_stride(
              std::vector<int64_t>(dO.strides().begin(), dO.strides().end())));
  auto sdpa_backward_options = fe::graph::SDPA_backward_attributes()
                                   .set_name("CUDNN_SDPA_BACKWARD")
                                   .set_causal_mask(is_causal)
                                   .set_attn_scale(attn_scale);
  if (dropout_probability != 0.0f) {
    sdpa_backward_options.set_dropout(dropout_probability, Seed, Offset);
  }
  auto [DQ, DK, DV] =
      mha_graph->sdpa_backward(Q, K, V, O, DO, STATS, sdpa_backward_options);
  DQ->set_output(true)
      .set_dim(std::vector<int64_t>(dQ.sizes().begin(), dQ.sizes().end()))
      .set_stride(
          std::vector<int64_t>(dQ.strides().begin(), dQ.strides().end()));
  DK->set_output(true)
      .set_dim(std::vector<int64_t>(dK.sizes().begin(), dK.sizes().end()))
      .set_stride(
          std::vector<int64_t>(dK.strides().begin(), dK.strides().end()));
  DV->set_output(true)
      .set_dim(std::vector<int64_t>(dV.sizes().begin(), dV.sizes().end()))
      .set_stride(
          std::vector<int64_t>(dV.strides().begin(), dV.strides().end()));
  AT_CUDNN_FRONTEND_CHECK(mha_graph->validate());
  AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle));
  AT_CUDNN_FRONTEND_CHECK(
      mha_graph->create_execution_plans({fe::HeurMode_t::A}));
  AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle));
  AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle));
  return std::make_tuple(
      std::move(mha_graph),
      std::move(Q),
      std::move(K),
      std::move(V),
      std::move(attn_scale),
      std::move(Seed),
      std::move(Offset),
      std::move(O),
      std::move(DO),
      std::move(STATS),
      std::move(DQ),
      std::move(DK),
      std::move(DV));
}

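// Forward entry point: allocates the output (and, when training, the softmax
// stats), fetches or builds the cached graph for this problem, binds device
// pointers, and executes the plan on the current cuDNN handle.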
void run_cudnn_SDP_fprop(
    int64_t b,
    int64_t h,
    int64_t s_q,
    int64_t s_kv,
    int64_t d_qk,
    int64_t d_v,
    float scaling_factor,
    bool return_softmaxstats,
    bool is_causal,
    double dropout_probability,
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    Tensor& softmaxstats,
    Tensor& o,
    Tensor& dropoutseed,
    Tensor& dropoutoffset) {
  cudnnHandle_t handle = getCudnnHandle();
  o = at::empty_strided(
      {b, h, s_q, d_v}, {s_q * h * d_v, d_v, h * d_v, 1}, q.options());
  if (return_softmaxstats) {
    // TODO(eqy): verify that this is correct
    softmaxstats = at::empty({b, h, s_q}, q.options().dtype(kFloat));
  }

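  // Look up a previously built graph for this exact problem; on a miss, build
  // one here and cache it after a successful execution below.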
  auto key = MHACacheKeyWrapper(
      b,
      h,
      s_q,
      s_kv,
      d_qk,
      d_v,
      q,
      k,
      v,
      dropout_probability,
      is_causal,
      return_softmaxstats);
  auto graph_and_tensors_ptr = mhagraphcache.find(key);
  graph_and_tensors graph_and_tensors_values;
  if (graph_and_tensors_ptr) {
    graph_and_tensors_values = *graph_and_tensors_ptr;
  } else {
    graph_and_tensors_values = build_graph_and_tensors(
        b,
        h,
        s_q,
        s_kv,
        d_qk,
        d_v,
        scaling_factor,
        return_softmaxstats,
        is_causal,
        dropout_probability,
        q,
        k,
        v,
        softmaxstats,
        o,
        dropoutseed,
        dropoutoffset,
        handle,
        key.pod);
  }
  auto [mha_graph, Q, K, V, attn_scale, seed, offset, O, Stats] =
      graph_and_tensors_values;
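  // The variant pack binds each graph tensor handle to the device pointer
  // (or, for the pass-by-value attn_scale, the host value) backing it for
  // this execution.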
  std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*>
      variant_pack = {
          {Q, q.data_ptr()},
          {K, k.data_ptr()},
          {V, v.data_ptr()},
          {attn_scale, &scaling_factor},
          //{bias, bias.data_ptr()},
          {seed, dropoutseed.data_ptr()},
          {offset, dropoutoffset.data_ptr()},
          {O, o.data_ptr()}};
  if (return_softmaxstats) {
    variant_pack[Stats] = softmaxstats.data_ptr();
  }
  auto workspace_size = mha_graph->get_workspace_size();
  auto workspace_ptr =
      c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
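  // Sanity-check the allocation (mirrors the equivalent check in
  // run_cudnn_SDP_bprop below); a zero-size workspace is legal.
  TORCH_CHECK(!workspace_size || workspace_ptr.get());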
  TORCH_CHECK(
      mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
  mhagraphcache.update(key, graph_and_tensors_values);
}

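// Backward entry point: same cache-then-execute flow as the forward pass.
// The cache key is built with return_softmaxstats fixed to true, matching
// the training-mode forward graph whose saved stats feed this backward pass.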
void run_cudnn_SDP_bprop(
    int64_t b,
    int64_t h,
    int64_t s_q,
    int64_t s_kv,
    int64_t d_qk,
    int64_t d_v,
    float scaling_factor,
    bool is_causal,
    float dropout_probability,
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    const Tensor& o,
    const Tensor& dO,
    const Tensor& softmaxstats,
    Tensor& dQ,
    Tensor& dK,
    Tensor& dV,
    const Tensor& dropoutseed,
    const Tensor& dropoutoffset) {
  cudnnHandle_t handle = getCudnnHandle();
  auto key = MHACacheKeyWrapper(
      b,
      h,
      s_q,
      s_kv,
      d_qk,
      d_v,
      q,
      k,
      v,
      dropout_probability,
      is_causal,
      true);
  auto graph_and_tensors_backward_ptr = mhagraphbackwardcache.find(key);
  graph_and_tensors_backward graph_and_tensors_backward_values;
  if (graph_and_tensors_backward_ptr) {
    graph_and_tensors_backward_values = *graph_and_tensors_backward_ptr;
  } else {
    graph_and_tensors_backward_values = build_graph_and_tensors_backward(
        b,
        h,
        s_q,
        s_kv,
        d_qk,
        d_v,
        scaling_factor,
        is_causal,
        dropout_probability,
        q,
        k,
        v,
        o,
        dO,
        softmaxstats,
        dQ,
        dK,
        dV,
        dropoutseed,
        dropoutoffset,
        handle,
        key.pod);
  }
  auto
      [mha_graph, Q, K, V, attn_scale, Seed, Offset, O, Do, Stats, Dq, Dk, Dv] =
          graph_and_tensors_backward_values;
  std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*>
      variant_pack = {// inputs
                      {Q, q.data_ptr()},
                      {K, k.data_ptr()},
                      {V, v.data_ptr()},
                      {O, o.data_ptr()},
                      {Do, dO.data_ptr()},
                      {Stats, softmaxstats.data_ptr()},
                      // outputs
                      {Dq, dQ.data_ptr()},
                      {Dk, dK.data_ptr()},
                      {Dv, dV.data_ptr()},
                      // pass by value
                      {attn_scale, &scaling_factor}};
  if (dropout_probability != 0.0f) {
    variant_pack[Seed] = dropoutseed.data_ptr();
    variant_pack[Offset] = dropoutoffset.data_ptr();
  }
  auto workspace_size = mha_graph->get_workspace_size();
  auto workspace_ptr =
      c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
  TORCH_CHECK(!workspace_size || workspace_ptr.get());
  TORCH_CHECK(
      mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
  mhagraphbackwardcache.update(key, graph_and_tensors_backward_values);
}

} // namespace native
} // namespace at
#endif