/aosp_15_r20/external/executorch/extension/llm/custom_ops/ |
H A D | op_sdpa_test.cpp | 28 bool is_causal, in op_scaled_dot_product_attention() 99 bool is_causal = false; in TEST() local 142 bool is_causal = false; in TEST() local 172 bool is_causal = true; in TEST() local 240 bool is_causal = false; in TEST() local 316 bool is_causal = false; in TEST() local 503 bool is_causal = false; in TEST() local
|
H A D | op_sdpa_with_kv_cache_test.cpp | 31 bool is_causal, in op_sdpa_with_kv_cache() 147 bool is_causal = false; in TEST() local 389 bool is_causal = false; in TEST() local 592 bool is_causal = false; in TEST() local 837 bool is_causal = false; in TEST() local
|
H A D | op_sdpa_aot.cpp | 32 const bool is_causal, in sdpa_with_kv_cache_out_no_context() 65 const bool is_causal, in sdpa_with_kv_cache_aten() 94 const bool is_causal, in custom_sdpa_out_no_context() 121 const bool is_causal, in custom_sdpa_aten()
|
H A D | sdpa_with_kv_cache.py | 49 is_causal, argument 116 is_causal=False, argument 152 is_causal=False, argument
|
H A D | op_sdpa.cpp | 223 bool is_causal, in cpu_flash_attention() 765 const bool is_causal, in flash_attention_kernel_out() 848 const bool is_causal, in custom_sdpa_out() 997 const bool is_causal, in sdpa_with_kv_cache_out()
|
/aosp_15_r20/external/pytorch/aten/src/ATen/native/cudnn/ |
H A D | MHA.cpp | 20 bool is_causal, in run_cudnn_SDP_fprop() 41 bool is_causal, in run_cudnn_SDP_bprop() 136 bool is_causal; member 152 bool is_causal, in setMHAParams() 263 bool is_causal, in build_graph_and_tensors() 397 bool is_causal, in build_graph_and_tensors_backward() 530 bool is_causal, in run_cudnn_SDP_fprop() 617 bool is_causal, in run_cudnn_SDP_bprop()
|
/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/ |
H A D | attention.cpp | 427 …const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> s… in _fused_sdp_choice_cpp() 451 bool is_causal, in _fused_sdp_choice_meta() 485 bool is_causal, in validate_sdpa_input() 602 …const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> s… in handle_private_use() 655 bool is_causal, in scaled_dot_product_attention() 717 const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, in _scaled_dot_product_attention_math() 776 bool is_causal, in _scaled_dot_product_flash_attention_cpu() 824 bool is_causal, in _scaled_dot_product_flash_attention_cpu_backward()
|
H A D | sdp_utils_cpp.h | 50 bool is_causal; member
|
/aosp_15_r20/external/pytorch/torch/nested/_internal/ |
H A D | sdpa.py | 29 is_causal=False, argument 264 def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal): argument 623 is_causal=False, argument
|
/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/ |
H A D | attention.cu | 689 bool is_causal, in _scaled_dot_product_flash_attention_cuda() 743 bool is_causal, in _scaled_dot_product_cudnn_attention_cuda() 796 bool is_causal, in _scaled_dot_product_efficient_attention_cuda() 830 …const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> s… in _fused_sdp_choice_cuda() 852 bool is_causal, in _flash_attention_forward() 1094 bool is_causal; in _efficient_attention_forward() local
|
H A D | attention_backward.cu | 73 bool is_causal, in _flash_attention_backward() 182 bool is_causal, in _scaled_dot_product_cudnn_attention_backward_cuda() 423 bool is_causal; in _efficient_attention_backward() local 744 bool is_causal, in _scaled_dot_product_flash_attention_backward_cuda()
|
/aosp_15_r20/external/pytorch/aten/src/ATen/native/nested/cuda/ |
H A D | NestedTensorTransformerFunctions.cpp | 234 bool is_causal, in _scaled_dot_product_flash_attention_nestedtensor_cuda() 289 bool is_causal, in _scaled_dot_product_efficient_attention_nestedtensor_cuda() 338 bool is_causal, in _scaled_dot_product_flash_attention_backward_nested()
|
/aosp_15_r20/external/executorch/backends/vulkan/test/op_tests/ |
H A D | sdpa_test.cpp | 44 const bool is_causal, in sdpa_with_kv_cache_out_no_context() 77 const bool is_causal, in sdpa_with_kv_cache_aten() 164 const bool is_causal, in sdpa_reference_impl()
|
/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/ |
H A D | flash_api.cpp | 357 bool is_causal, in mha_fwd() 557 bool is_causal, in mha_varlen_fwd() 819 const bool is_causal, in mha_bwd() 1035 const bool is_causal, in mha_varlen_bwd() 1266 bool is_causal, in mha_fwd_kvcache()
|
H A D | flash.h | 127 bool is_causal; member
|
/aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/ |
H A D | FlashAttentionKernel.cpp | 181 bool is_causal, in cpu_flash_attention() 431 bool is_causal, in cpu_flash_attention_backward() 736 bool is_causal, in flash_attention_kernel_impl() 787 bool is_causal, in flash_attention_backward_kernel_impl()
|
/aosp_15_r20/external/pytorch/test/cpp_extensions/ |
H A D | open_registration_extension.cpp | 131 …const std::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, std::optional<doubl… in _fused_sdp_choice_privateuse1() 451 bool is_causal, in custom_scaled_dot_product_fused_attention_overrideable() 486 bool is_causal, in custom_scaled_dot_product_fused_attention_overrideable_backward()
|
/aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_torch/ |
H A D | shim_common.cpp | 517 int is_causal, in aoti_torch__scaled_dot_product_flash_attention_v2() 567 bool is_causal, in aoti_torch__scaled_dot_product_flash_attention() 607 int is_causal, in aoti_torch__scaled_dot_product_efficient_attention()
|
/aosp_15_r20/external/pytorch/aten/src/ATen/functorch/ |
H A D | BatchRulesLinearAlgebra.cpp | 495 bool is_causal, in _scaled_dot_product_flash_attention_batch_rule() 545 bool is_causal, in _scaled_dot_product_efficient_attention_batch_rule() 586 bool is_causal, in _scaled_dot_product_cudnn_attention_batch_rule()
|
/aosp_15_r20/external/pytorch/torch/nn/attention/ |
H A D | _utils.py | 48 is_causal=False, argument
|
/aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/impl/ |
H A D | SDPA.cpp | 192 const ValueRef is_causal = args[arg_idx++]; in sdpa_with_kv_cache_impl() local
|
/aosp_15_r20/external/pytorch/test/ |
H A D | test_transformers.py | 969 …def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim, is_causal, dropout_p… argument 1703 def _get_block_size_n(device, head_dim, is_dropout, is_causal): argument
|
H A D | test_decomp.py | 1188 self, query_layer, key_layer, value_layer, mask=None, is_causal=True argument
|
/aosp_15_r20/external/pytorch/torch/csrc/ |
H A D | Module.cpp | 1996 bool enable_gqa) { in initModule()
|