Home
last modified time | relevance | path

Searched defs:attn_mask (Results 1 – 24 of 24) sorted by relevance

/aosp_15_r20/external/pytorch/torch/_inductor/fx_passes/
H A Dfuse_attention.py119 def _sfdp_pattern_5(query, key, value, attn_mask): argument
127 def _sfdp_replacement_5(query, key, value, attn_mask): argument
139 def _sfdp_pattern_6(query, key, value, attn_mask, dropout_p): argument
147 def _sfdp_replacement_6(query, key, value, attn_mask, dropout_p): argument
338 def _sfdp_pattern_14(query, key, value, attn_mask, inv_scale): argument
351 def _sfdp_replacement_14(query, key, value, attn_mask, inv_scale): argument
364 def _sfdp_pattern_15(query, key, value, attn_mask, inv_scale): argument
380 def _sfdp_replacement_15(query, key, value, attn_mask, inv_scale): argument
401 def _sfdp_pattern_16(query, key, value, attn_mask, inv_scale, dropout_p): argument
418 def _sfdp_replacement_16(query, key, value, attn_mask, inv_scale, dropout_p): argument
[all …]
/aosp_15_r20/external/executorch/extension/llm/custom_ops/
H A Dop_sdpa_test.cpp26 const exec_aten::optional<exec_aten::Tensor>& attn_mask, in op_scaled_dot_product_attention()
97 exec_aten::optional<exec_aten::Tensor> attn_mask; in TEST() local
138 exec_aten::optional<exec_aten::Tensor> attn_mask = in TEST() local
170 exec_aten::optional<exec_aten::Tensor> attn_mask; in TEST() local
238 exec_aten::optional<exec_aten::Tensor> attn_mask; in TEST() local
314 exec_aten::optional<exec_aten::Tensor> attn_mask; in TEST() local
496 exec_aten::optional<exec_aten::Tensor> attn_mask = in TEST() local
H A Dop_sdpa_with_kv_cache_test.cpp29 const exec_aten::optional<exec_aten::Tensor>& attn_mask, in op_sdpa_with_kv_cache()
145 exec_aten::optional<exec_aten::Tensor> attn_mask; in TEST() local
387 exec_aten::optional<exec_aten::Tensor> attn_mask; in TEST() local
584 exec_aten::Tensor attn_mask = tfFloat.make({1, 1}, {0}); in TEST() local
835 exec_aten::optional<exec_aten::Tensor> attn_mask; in TEST() local
H A Dop_sdpa_aot.cpp30 const optional<Tensor> attn_mask, in sdpa_with_kv_cache_out_no_context()
63 const std::optional<at::Tensor> attn_mask, in sdpa_with_kv_cache_aten()
92 const optional<Tensor> attn_mask, in custom_sdpa_out_no_context()
119 const std::optional<at::Tensor> attn_mask, in custom_sdpa_aten()
H A Dsdpa_with_kv_cache.py47 attn_mask, argument
114 attn_mask=None, argument
150 attn_mask=None, argument
H A Dop_sdpa.cpp224 const optional<Tensor>& attn_mask, in cpu_flash_attention()
596 const optional<Tensor>& attn_mask) { in validate_flash_attention_args()
763 const optional<Tensor>& attn_mask, in flash_attention_kernel_out()
846 const optional<Tensor>& attn_mask, in custom_sdpa_out()
995 const optional<Tensor>& attn_mask, in sdpa_with_kv_cache_out()
H A Dtest_sdpa_with_kv_cache.py17 def _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, attn_mask, start_pos, seq_len): argument
/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/
H A Dattention.cpp110 std::optional<Tensor> attn_mask, in masked_softmax()
516 std::optional<Tensor> convert_boolean_attn_mask(const std::optional<Tensor>& attn_mask, caffe2::Typ… in convert_boolean_attn_mask()
664 std::optional<Tensor> attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); in scaled_dot_product_attention() local
726 auto attn_mask = attn_mask_; in _scaled_dot_product_attention_math() local
777 const std::optional<Tensor>& attn_mask, in _scaled_dot_product_flash_attention_cpu()
825 const std::optional<Tensor>& attn_mask, in _scaled_dot_product_flash_attention_cpu_backward()
H A Dsdp_utils_cpp.h48 std::optional<at::Tensor> attn_mask; member
279 auto attn_mask = params.attn_mask; in check_attn_mask_shape() local
/aosp_15_r20/external/executorch/backends/vulkan/test/op_tests/
H A Dsdpa_test.cpp42 const optional<Tensor> attn_mask, in sdpa_with_kv_cache_out_no_context()
75 const std::optional<at::Tensor> attn_mask, in sdpa_with_kv_cache_aten()
110 const at::Tensor& attn_mask, in convert_boolean_attn_mask()
166 at::Tensor attn_mask = in sdpa_reference_impl() local
/aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/
H A DFlashAttentionKernel.cpp149 Tensor& attn_mask, in reshape_attn_mask_to_4d()
182 std::optional<Tensor> attn_mask, in cpu_flash_attention()
432 std::optional<Tensor> attn_mask, in cpu_flash_attention_backward()
737 std::optional<Tensor> attn_mask, in flash_attention_kernel_impl()
788 std::optional<Tensor> attn_mask, in flash_attention_backward_kernel_impl()
/aosp_15_r20/external/pytorch/torch/nested/_internal/
H A Dsdpa.py264 def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal): argument
/aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/impl/
H A DSDPA.cpp190 const ValueRef attn_mask = args[arg_idx++]; in sdpa_with_kv_cache_impl() local
/aosp_15_r20/external/pytorch/aten/src/ATen/native/nested/
H A DNestedTensorTransformerFunctions.cpp186 std::optional<Tensor> attn_mask; in NestedTensor_softmax_dropout_cuda() local
/aosp_15_r20/external/pytorch/benchmarks/transformer/
H A Dscore_mod.py148 def eager_sdpa(query, key, value, attn_mask): argument
/aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/
H A Dactivation.cpp446 const Tensor& attn_mask, in forward()
/aosp_15_r20/external/pytorch/test/
H A Dtest_transformers.py974 attn_mask=None, argument
1092 …def _test_fastpath(model, key_padding_mask, mock_return_value, attn_mask=None, nested_tensors=True… argument
H A Dtest_jit.py14924 attn_mask=None # type: Optional[Tensor] argument
/aosp_15_r20/external/pytorch/test/cpp_extensions/
H A Dopen_registration_extension.cpp131 const std::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, std::optional<doubl… in _fused_sdp_choice_privateuse1()
/aosp_15_r20/external/pytorch/test/inductor/
H A Dtest_fused_attention.py551 def forward(self, query, key, value, attn_mask) -> torch.Tensor: argument
/aosp_15_r20/external/pytorch/torch/csrc/
H A DModule.cpp1996 bool enable_gqa) { in initModule()
/aosp_15_r20/external/executorch/backends/qualcomm/tests/
H A Dmodels.py862 def forward(self, query_layer, key_layer, value_layer, attn_mask): argument
/aosp_15_r20/external/pytorch/test/dynamo/
H A Dtest_repros.py888 def _sa_block(self, x, attn_mask, key_padding_mask): argument
/aosp_15_r20/external/pytorch/test/cpp/api/
H A Dmodules.cpp3556 auto attn_mask = torch::randint(0, 2, {1, seq_len}, torch::kFloat); in _multihead_attn_test_helper() local