Home
last modified time | relevance | path

Searched defs:is_causal (Results 1 – 24 of 24) sorted by relevance

/aosp_15_r20/external/executorch/extension/llm/custom_ops/
H A Dop_sdpa_test.cpp28 bool is_causal, in op_scaled_dot_product_attention()
99 bool is_causal = false; in TEST() local
142 bool is_causal = false; in TEST() local
172 bool is_causal = true; in TEST() local
240 bool is_causal = false; in TEST() local
316 bool is_causal = false; in TEST() local
503 bool is_causal = false; in TEST() local
H A Dop_sdpa_with_kv_cache_test.cpp31 bool is_causal, in op_sdpa_with_kv_cache()
147 bool is_causal = false; in TEST() local
389 bool is_causal = false; in TEST() local
592 bool is_causal = false; in TEST() local
837 bool is_causal = false; in TEST() local
H A Dop_sdpa_aot.cpp32 const bool is_causal, in sdpa_with_kv_cache_out_no_context()
65 const bool is_causal, in sdpa_with_kv_cache_aten()
94 const bool is_causal, in custom_sdpa_out_no_context()
121 const bool is_causal, in custom_sdpa_aten()
H A Dsdpa_with_kv_cache.py49 is_causal, argument
116 is_causal=False, argument
152 is_causal=False, argument
H A Dop_sdpa.cpp223 bool is_causal, in cpu_flash_attention()
765 const bool is_causal, in flash_attention_kernel_out()
848 const bool is_causal, in custom_sdpa_out()
997 const bool is_causal, in sdpa_with_kv_cache_out()
/aosp_15_r20/external/pytorch/aten/src/ATen/native/cudnn/
H A DMHA.cpp20 bool is_causal, in run_cudnn_SDP_fprop()
41 bool is_causal, in run_cudnn_SDP_bprop()
136 bool is_causal; member
152 bool is_causal, in setMHAParams()
263 bool is_causal, in build_graph_and_tensors()
397 bool is_causal, in build_graph_and_tensors_backward()
530 bool is_causal, in run_cudnn_SDP_fprop()
617 bool is_causal, in run_cudnn_SDP_bprop()
/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/
H A Dattention.cpp427 …const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> s… in _fused_sdp_choice_cpp()
451 bool is_causal, in _fused_sdp_choice_meta()
485 bool is_causal, in validate_sdpa_input()
602 …const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> s… in handle_private_use()
655 bool is_causal, in scaled_dot_product_attention()
717 const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, in _scaled_dot_product_attention_math()
776 bool is_causal, in _scaled_dot_product_flash_attention_cpu()
824 bool is_causal, in _scaled_dot_product_flash_attention_cpu_backward()
H A Dsdp_utils_cpp.h50 bool is_causal; member
/aosp_15_r20/external/pytorch/torch/nested/_internal/
H A Dsdpa.py29 is_causal=False, argument
264 def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal): argument
623 is_causal=False, argument
/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/
H A Dattention.cu689 bool is_causal, in _scaled_dot_product_flash_attention_cuda()
743 bool is_causal, in _scaled_dot_product_cudnn_attention_cuda()
796 bool is_causal, in _scaled_dot_product_efficient_attention_cuda()
830 …const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> s… in _fused_sdp_choice_cuda()
852 bool is_causal, in _flash_attention_forward()
1094 bool is_causal; in _efficient_attention_forward() local
H A Dattention_backward.cu73 bool is_causal, in _flash_attention_backward()
182 bool is_causal, in _scaled_dot_product_cudnn_attention_backward_cuda()
423 bool is_causal; in _efficient_attention_backward() local
744 bool is_causal, in _scaled_dot_product_flash_attention_backward_cuda()
/aosp_15_r20/external/pytorch/aten/src/ATen/native/nested/cuda/
H A DNestedTensorTransformerFunctions.cpp234 bool is_causal, in _scaled_dot_product_flash_attention_nestedtensor_cuda()
289 bool is_causal, in _scaled_dot_product_efficient_attention_nestedtensor_cuda()
338 bool is_causal, in _scaled_dot_product_flash_attention_backward_nested()
/aosp_15_r20/external/executorch/backends/vulkan/test/op_tests/
H A Dsdpa_test.cpp44 const bool is_causal, in sdpa_with_kv_cache_out_no_context()
77 const bool is_causal, in sdpa_with_kv_cache_aten()
164 const bool is_causal, in sdpa_reference_impl()
/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/
H A Dflash_api.cpp357 bool is_causal, in mha_fwd()
557 bool is_causal, in mha_varlen_fwd()
819 const bool is_causal, in mha_bwd()
1035 const bool is_causal, in mha_varlen_bwd()
1266 bool is_causal, in mha_fwd_kvcache()
H A Dflash.h127 bool is_causal; member
/aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/
H A DFlashAttentionKernel.cpp181 bool is_causal, in cpu_flash_attention()
431 bool is_causal, in cpu_flash_attention_backward()
736 bool is_causal, in flash_attention_kernel_impl()
787 bool is_causal, in flash_attention_backward_kernel_impl()
/aosp_15_r20/external/pytorch/test/cpp_extensions/
H A Dopen_registration_extension.cpp131 …const std::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, std::optional<doubl… in _fused_sdp_choice_privateuse1()
451 bool is_causal, in custom_scaled_dot_product_fused_attention_overrideable()
486 bool is_causal, in custom_scaled_dot_product_fused_attention_overrideable_backward()
/aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_torch/
H A Dshim_common.cpp517 int is_causal, in aoti_torch__scaled_dot_product_flash_attention_v2()
567 bool is_causal, in aoti_torch__scaled_dot_product_flash_attention()
607 int is_causal, in aoti_torch__scaled_dot_product_efficient_attention()
/aosp_15_r20/external/pytorch/aten/src/ATen/functorch/
H A DBatchRulesLinearAlgebra.cpp495 bool is_causal, in _scaled_dot_product_flash_attention_batch_rule()
545 bool is_causal, in _scaled_dot_product_efficient_attention_batch_rule()
586 bool is_causal, in _scaled_dot_product_cudnn_attention_batch_rule()
/aosp_15_r20/external/pytorch/torch/nn/attention/
H A D_utils.py48 is_causal=False, argument
/aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/impl/
H A DSDPA.cpp192 const ValueRef is_causal = args[arg_idx++]; in sdpa_with_kv_cache_impl() local
/aosp_15_r20/external/pytorch/test/
H A Dtest_transformers.py969 …def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim, is_causal, dropout_p… argument
1703 def _get_block_size_n(device, head_dim, is_dropout, is_causal): argument
H A Dtest_decomp.py1188 self, query_layer, key_layer, value_layer, mask=None, is_causal=True argument
/aosp_15_r20/external/pytorch/torch/csrc/
H A DModule.cpp1996 bool enable_gqa) { in initModule()