xref: /aosp_15_r20/external/executorch/extension/llm/custom_ops/op_sdpa.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <executorch/runtime/kernel/kernel_includes.h>
12 
13 namespace torch {
14 namespace executor {
15 
16 namespace native {
17 
18 Tensor& sdpa_with_kv_cache_out(
19     KernelRuntimeContext& ctx,
20     const Tensor& q_projected,
21     const Tensor& k_projected,
22     const Tensor& v_projected,
23     Tensor& key_cache,
24     Tensor& value_cache,
25     const int64_t start_pos,
26     const int64_t seq_len,
27     const optional<Tensor>& attn_mask,
28     const double dropout_p,
29     const bool is_causal,
30     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
31     const optional<double> scale,
32     Tensor& output);
33 
34 Tensor& custom_sdpa_out(
35     RuntimeContext& ctx,
36     const Tensor& q,
37     const Tensor& k,
38     const Tensor& v,
39     const int64_t start_pos,
40     const optional<Tensor>& attn_mask,
41     const double dropout_p,
42     const bool is_causal,
43     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
44     const optional<double> scale,
45     Tensor& output);
46 
47 Tensor& flash_attention_kernel_out(
48     KernelRuntimeContext& ctx,
49     const Tensor& query,
50     const Tensor& key,
51     const Tensor& value,
52     const optional<Tensor>& attn_mask,
53     const double dropout_p,
54     const bool is_causal,
55     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
56     const optional<double> scale,
57     Tensor& output);
58 
59 } // namespace native
60 } // namespace executor
61 } // namespace torch
62