Searched definitions: max_seqlen_batch_k (results 1–3 of 3 files, sorted by relevance)

/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/
attention_backward.cu
  71   int64_t max_seqlen_batch_k,  in _flash_attention_backward()
  199  const int64_t max_seqlen_batch_k = key.size(2);  in _scaled_dot_product_cudnn_attention_backward_cuda() [local]
  742  const int64_t max_seqlen_batch_k,  in _scaled_dot_product_flash_attention_backward_cuda()
attention.cu
  699  const int64_t max_seqlen_batch_k = key.size(2);  in _scaled_dot_product_flash_attention_cuda() [local]
  756  const int64_t max_seqlen_batch_k = key.size(2);  in _scaled_dot_product_cudnn_attention_cuda() [local]
  850  int64_t max_seqlen_batch_k,  in _flash_attention_forward()
/aosp_15_r20/external/pytorch/aten/src/ATen/native/nested/cuda/
NestedTensorTransformerFunctions.cpp
  336  const int64_t max_seqlen_batch_k,  in _scaled_dot_product_flash_attention_backward_nested()