Searched defs:attn_bias (Results 1 – 10 of 10) sorted by relevance
/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/

  attention_backward.cu
    176  const Tensor& attn_bias, in _scaled_dot_product_cudnn_attention_backward_cuda()
    789  const at::Tensor& attn_bias, in _scaled_dot_product_efficient_attention_backward_cuda()
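These backward kernels are what autograd reaches when a fused attention call carried a bias. A minimal sketch (mine, not from the tree): a floating-point attn_mask acts as an additive bias, and differentiating through it produces a gradient for the bias tensor as well. On a CUDA build the efficient/cuDNN backends above are eligible; on CPU the same script falls back to the math path.

    import torch
    import torch.nn.functional as F

    q = torch.randn(2, 4, 128, 64, requires_grad=True)
    k = torch.randn(2, 4, 128, 64, requires_grad=True)
    v = torch.randn(2, 4, 128, 64, requires_grad=True)
    # A floating-point attn_mask is added to the attention scores.
    bias = torch.randn(2, 4, 128, 128, requires_grad=True)

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
    out.sum().backward()
    print(bias.grad.shape)  # the bias receives a gradient too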
/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/

  attention.cu
    793  const std::optional<at::Tensor>& attn_bias, in _scaled_dot_product_efficient_attention_cuda()
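The forward entry point can also be exercised through the private op directly. A hedged sketch, assuming a CUDA build and the four-output overload of recent PyTorch 2.x (this schema has changed across releases, so check your version):

    import torch

    dev, dt = "cuda", torch.float16
    q = torch.randn(2, 4, 128, 64, device=dev, dtype=dt)
    k = torch.randn(2, 4, 128, 64, device=dev, dtype=dt)
    v = torch.randn(2, 4, 128, 64, device=dev, dtype=dt)
    bias = torch.randn(2, 4, 128, 128, device=dev, dtype=dt)

    # compute_log_sumexp=True also materializes the per-row logsumexp
    # that the backward kernel reuses.
    out, lse, seed, offset = torch.ops.aten._scaled_dot_product_efficient_attention(
        q, k, v, bias, compute_log_sumexp=True)
    print(out.shape, lse.shape)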
/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/

  attention.cpp
    550  at::Tensor pad_bias(const at::Tensor& attn_bias) { in pad_bias()
    579  at::Tensor pad_last_dim(const at::Tensor& attn_bias) { in pad_last_dim()
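pad_bias and pad_last_dim exist because the memory-efficient kernel wants the bias's last dimension aligned. A rough Python rendering of the idea (the 16-element alignment is my assumption for illustration): zero-pad the last dim out to the alignment, then slice back so the logical shape is unchanged but the underlying row stride is aligned.

    import torch
    import torch.nn.functional as F

    def pad_last_dim(attn_bias: torch.Tensor, alignment: int = 16) -> torch.Tensor:
        last = attn_bias.size(-1)
        pad = (-last) % alignment
        if pad == 0:
            return attn_bias
        # Pad only the last dimension, then slice back to the logical size;
        # the slice keeps the padded (aligned) row stride.
        return F.pad(attn_bias, (0, pad))[..., :last]

    bias = torch.randn(2, 4, 128, 100)
    aligned = pad_last_dim(bias)
    print(aligned.shape, aligned.stride(-2))  # shape unchanged, row stride now 112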
/aosp_15_r20/external/pytorch/test/cpp_extensions/

  open_registration_extension.cpp
    449  const std::optional<at::Tensor> & attn_bias, in custom_scaled_dot_product_fused_attention_overrideable()
    477  const at::Tensor & attn_bias, in custom_scaled_dot_product_fused_attention_overrideable_backward()
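This test extension overrides the "overrideable" fused-attention hooks for a custom (PrivateUse1) device. What follows is not the extension's actual registration, only a self-contained sketch of the same torch.library pattern, with a made-up "demo" namespace and op name carrying the optional attn_bias argument:

    import torch

    lib = torch.library.Library("demo", "DEF")  # hypothetical namespace
    lib.define("sdpa_with_bias(Tensor q, Tensor k, Tensor v, Tensor? attn_bias) -> Tensor")

    def sdpa_with_bias(q, k, v, attn_bias=None):
        return torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_bias)

    lib.impl("sdpa_with_bias", sdpa_with_bias, "CompositeExplicitAutograd")

    q = k = v = torch.randn(2, 4, 8, 16)
    print(torch.ops.demo.sdpa_with_bias(q, k, v, None).shape)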
/aosp_15_r20/external/pytorch/aten/src/ATen/functorch/

  BatchRulesLinearAlgebra.cpp
    542  const std::optional<Tensor>& attn_bias, optional<int64_t> attn_bias_bdim, in _scaled_dot_product_efficient_attention_batch_rule()
    583  const std::optional<Tensor>& attn_bias, std::optional<int64_t> attn_bias_bdim, in _scaled_dot_product_cudnn_attention_batch_rule()
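These batch rules are what make attn_bias vmappable: attn_bias_bdim records which dimension of the bias, if any, is the vmapped one. Mapping over a stack of biases while closing over fixed q/k/v should exercise exactly this rule:

    import torch
    import torch.nn.functional as F

    q = torch.randn(4, 128, 64)
    k = torch.randn(4, 128, 64)
    v = torch.randn(4, 128, 64)
    biases = torch.randn(8, 4, 128, 128)  # eight candidate biases

    attend = lambda b: F.scaled_dot_product_attention(q, k, v, attn_mask=b)
    outs = torch.vmap(attend)(biases)     # the batch rule handles the bias's bdim
    print(outs.shape)                     # torch.Size([8, 4, 128, 64])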
/aosp_15_r20/external/pytorch/aten/src/ATen/native/nested/cuda/

  NestedTensorTransformerFunctions.cpp
    286  const std::optional<at::Tensor>& attn_bias, in _scaled_dot_product_efficient_attention_nestedtensor_cuda()
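The nested-tensor variant takes the same optional attn_bias for ragged batches. A sketch, assuming a CUDA build (the nested fast paths live under nested/cuda) and leaving the bias at None, which is the broadly supported case for nested inputs:

    import torch
    import torch.nn.functional as F

    # Two sequences of different lengths, each (heads, seq_len, head_dim).
    qs = [torch.randn(4, n, 64, device="cuda", dtype=torch.float16) for n in (10, 17)]
    q = torch.nested.nested_tensor(qs)
    out = F.scaled_dot_product_attention(q, q, q)  # attn_mask/bias defaults to None
    print(out.is_nested)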
/aosp_15_r20/external/pytorch/test/

  test_transformers.py
    2019  attn_bias=None, argument
    3422  attn_bias=None, argument
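The test helpers default attn_bias=None and compare the fused kernels against plain math. That reference is just softmax(q kᵀ / sqrt(E) + attn_bias) v; a runnable version of the comparison:

    import math
    import torch
    import torch.nn.functional as F

    def ref_sdpa(q, k, v, attn_bias=None):
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
        if attn_bias is not None:
            scores = scores + attn_bias
        return torch.softmax(scores, dim=-1) @ v

    q = k = v = torch.randn(2, 4, 16, 8)
    bias = torch.randn(2, 4, 16, 16)
    torch.testing.assert_close(
        ref_sdpa(q, k, v, bias),
        F.scaled_dot_product_attention(q, k, v, attn_mask=bias),
        rtol=1e-4, atol=1e-4)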
/aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_torch/

  shim_common.cpp
    604  AtenTensorHandle attn_bias, // optional argument in aoti_torch__scaled_dot_product_efficient_attention()
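The shim gives AOTInductor-generated C++ a stable C entry point into the op, with attn_bias passed as a nullable handle. From the Python side, compiling a module whose forward takes a bias (as the inductor tests below do) is what eventually funnels into it on CUDA; a small sketch using torch.compile:

    import torch
    import torch.nn.functional as F

    class Attn(torch.nn.Module):
        def forward(self, q, k, v, attn_bias):
            return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)

    m = torch.compile(Attn())
    q = k = v = torch.randn(2, 4, 32, 16)
    print(m(q, k, v, torch.randn(2, 4, 32, 32)).shape)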
/aosp_15_r20/external/pytorch/test/inductor/

  test_aot_inductor.py
    2847  def forward(self, q, k, v, attn_bias): argument
/aosp_15_r20/external/pytorch/test/inductor/

  test_torchinductor.py
    9655  def fn(q, k, v, attn_bias, compute_log_sumexp): argument
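compute_log_sumexp, the extra flag in the last test, asks the kernel to also return logsumexp over each score row; the backward pass recovers the softmax from it instead of storing the full probability matrix. The identity in plain PyTorch:

    import math
    import torch

    q = torch.randn(2, 4, 16, 8)
    k = torch.randn(2, 4, 16, 8)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))

    lse = torch.logsumexp(scores, dim=-1)          # what the kernel returns
    probs = torch.exp(scores - lse.unsqueeze(-1))  # softmax recovered from it
    print(torch.allclose(probs.sum(-1), torch.ones(2, 4, 16)))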