Searched defs:softmax_lse (Results 1 – 3 of 3) sorted by relevance

/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/
  flash_api.cpp
     454  auto softmax_lse = at::empty({batch_size, num_heads, seqlen_q }, opts.dtype(at::kFloat));   [in mha_fwd() local]
     693  auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));   [in mha_varlen_fwd() local]
     812  const at::Tensor &softmax_lse, // b x h x seqlen_q   [in mha_bwd()]
    1023  const at::Tensor &softmax_lse, // b x h x s softmax logsumexp   [in mha_varlen_bwd()]
    1389  auto softmax_lse = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));   [in mha_fwd_kvcache() local]
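The flash_api.cpp hits allocate softmax_lse as a float32 tensor of shape [batch_size, num_heads, seqlen_q], and the mha_varlen_bwd() comment calls it the "softmax logsumexp". As a hedged illustration of what each entry holds, the sketch below computes, for every query row, the log of the sum of exponentiated scaled attention scores. It is a plain, unfused C++ reference, not the FlashAttention kernel. The function name softmax_lse_reference, the contiguous [b, h, seqlen, head_dim] layout for q and k, and softmax_scale as a plain multiplier are assumptions of this sketch; causal masking, dropout, and the variable-length packing used by mha_varlen_fwd() are ignored.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical reference (unfused) computation of a softmax log-sum-exp buffer
// laid out as [batch_size, num_heads, seqlen_q] with float32 entries.
// Assumed inputs (not taken from flash_api.cpp):
//   q: contiguous [batch_size, num_heads, seqlen_q, head_dim]
//   k: contiguous [batch_size, num_heads, seqlen_k, head_dim]
// No causal mask, dropout, or variable-length packing is modeled.
std::vector<float> softmax_lse_reference(const std::vector<float>& q,
                                         const std::vector<float>& k,
                                         int batch_size, int num_heads,
                                         int seqlen_q, int seqlen_k,
                                         int head_dim, float softmax_scale) {
  const std::size_t bh_count = static_cast<std::size_t>(batch_size) * num_heads;
  assert(seqlen_k > 0);
  assert(q.size() == bh_count * seqlen_q * head_dim);
  assert(k.size() == bh_count * seqlen_k * head_dim);

  std::vector<float> lse(bh_count * seqlen_q);
  std::vector<float> scores(seqlen_k);
  for (std::size_t bh = 0; bh < bh_count; ++bh) {   // bh = b * num_heads + h
    for (int i = 0; i < seqlen_q; ++i) {            // one lse entry per query row
      const float* qi = &q[(bh * seqlen_q + i) * head_dim];
      float row_max = -std::numeric_limits<float>::infinity();
      for (int j = 0; j < seqlen_k; ++j) {
        const float* kj = &k[(bh * seqlen_k + j) * head_dim];
        float dot = 0.0f;
        for (int d = 0; d < head_dim; ++d) dot += qi[d] * kj[d];
        scores[j] = dot * softmax_scale;            // scaled score s_ij
        row_max = std::max(row_max, scores[j]);
      }
      // Numerically stable log-sum-exp: m + log(sum_j exp(s_ij - m)).
      float sum = 0.0f;
      for (int j = 0; j < seqlen_k; ++j) sum += std::exp(scores[j] - row_max);
      lse[bh * seqlen_q + i] = row_max + std::log(sum);
    }
  }
  return lse;
}
```

Under this sketch's definition, keeping one scalar lse_i per query row is enough to recover any softmax probability as exp(s_ij - lse_i) from a recomputed score, without storing the full seqlen_q x seqlen_k matrix. That is consistent with mha_bwd() and mha_varlen_bwd() taking softmax_lse as an input.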
/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/
  attention_backward.cu
     439  at::Tensor softmax_lse = logsumexp.view({B * nH, max_seqlen_q});   [in _efficient_attention_backward() local]
  attention.cu
    1089  at::Tensor softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q});   [in _efficient_attention_forward() local]
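In attention_backward.cu and attention.cu, softmax_lse is not allocated. It is created with logsumexp.view(...), which merges the batch and head dimensions into one. The sketch below shows the index arithmetic that makes this a zero-copy reshape. It assumes logsumexp is a contiguous, row-major tensor of shape [B, nH, max_seqlen_q]: the view's element count is consistent with that shape, but these search hits do not show the allocation. The helper names offset_3d, offset_2d, merged_row, and split_row are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>

// Index arithmetic behind viewing a contiguous [B, nH, S] tensor as [B * nH, S],
// as in logsumexp.view({B * nH, max_seqlen_q}) with S = max_seqlen_q.
// Assumes row-major, contiguous storage. All helper names are hypothetical.

// Flat offset of element (b, h, i) in the original [B, nH, S] layout.
std::size_t offset_3d(std::size_t b, std::size_t h, std::size_t i,
                      std::size_t nH, std::size_t S) {
  return (b * nH + h) * S + i;
}

// Flat offset of element (r, i) in the merged [B * nH, S] view.
std::size_t offset_2d(std::size_t r, std::size_t i, std::size_t S) {
  return r * S + i;
}

// Row of the merged view that holds batch b, head h.
std::size_t merged_row(std::size_t b, std::size_t h, std::size_t nH) {
  return b * nH + h;
}

// Inverse mapping: recover (batch, head) from a row of the merged view.
std::pair<std::size_t, std::size_t> split_row(std::size_t r, std::size_t nH) {
  assert(nH > 0);
  return {r / nH, r % nH};
}
```

Because offset_3d(b, h, i) equals offset_2d(merged_row(b, h), i) for every element, the view only changes shape metadata; no data moves. A non-contiguous logsumexp would need a copy (or a reshape) instead.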