Home
last modified time | relevance | path

Searched defs:Headdim (Results 1 – 2 of 2) sorted by relevance

/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/
H A Dflash_fwd_launch_template.h176 constexpr static int Headdim = 32; in run_mha_fwd_hdim32() local
186 constexpr static int Headdim = 64; in run_mha_fwd_hdim64() local
208 constexpr static int Headdim = 96; in run_mha_fwd_hdim96() local
234 constexpr static int Headdim = 128; in run_mha_fwd_hdim128() local
271 constexpr static int Headdim = 160; in run_mha_fwd_hdim160() local
301 constexpr static int Headdim = 192; in run_mha_fwd_hdim192() local
320 constexpr static int Headdim = 224; in run_mha_fwd_hdim224() local
349 constexpr static int Headdim = 256; in run_mha_fwd_hdim256() local
H A Dflash_bwd_launch_template.h138 constexpr static int Headdim = 32; in run_mha_bwd_hdim32() local
162 constexpr static int Headdim = 64; in run_mha_bwd_hdim64() local
207 constexpr static int Headdim = 96; in run_mha_bwd_hdim96() local
233 constexpr static int Headdim = 128; in run_mha_bwd_hdim128() local
267 constexpr static int Headdim = 160; in run_mha_bwd_hdim160() local
287 constexpr static int Headdim = 192; in run_mha_bwd_hdim192() local
307 constexpr static int Headdim = 224; in run_mha_bwd_hdim224() local
315 constexpr static int Headdim = 256; in run_mha_bwd_hdim256() local