Searched defs:head_size_rounded (Results 1 – 1 of 1) sorted by relevance

/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/
flash_api.cpp
299 const int head_size_rounded, const float p_dropout, in set_params_splitkv()
444 const int head_size_rounded = round_multiple(head_size, 32); in mha_fwd() local
683 const int head_size_rounded = round_multiple(head_size, 32); in mha_varlen_fwd() local
881 const int head_size_rounded = round_multiple(head_size, 32); in mha_bwd() local
1103 const int head_size_rounded = round_multiple(head_size, 32); in mha_varlen_bwd() local
1379 const int head_size_rounded = round_multiple(head_size, 32); in mha_fwd_kvcache() local