/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/extension/llm/custom_ops/op_sdpa.h>
#include <executorch/extension/llm/custom_ops/op_update_quantized_cache.h>

#include <torch/library.h>

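// This file registers ATen-compatible bindings for the ExecuTorch LLM custom
// ops (sdpa_with_kv_cache, custom_sdpa, update_quantized_cache) via
// torch::library, so the same kernels can be exercised from eager PyTorch
// (e.g. for testing and export) outside the ExecuTorch runtime.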
namespace torch {
namespace executor {

namespace native {
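// Adapts the ExecuTorch out-variant kernel to a context-free signature by
// constructing a default KernelRuntimeContext locally, which is the shape
// WRAP_TO_ATEN expects when bridging to ATen below.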
Tensor& sdpa_with_kv_cache_out_no_context(
    const Tensor& q_projected,
    const Tensor& k_projected,
    const Tensor& v_projected,
    Tensor& key_cache,
    Tensor& value_cache,
    const int64_t start_pos,
    const int64_t seq_len,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<double> scale,
    Tensor& output) {
  executorch::runtime::KernelRuntimeContext context{};
  return torch::executor::native::sdpa_with_kv_cache_out(
      context,
      q_projected,
      k_projected,
      v_projected,
      key_cache,
      value_cache,
      start_pos,
      seq_len,
      attn_mask,
      dropout_p,
      is_causal,
      scale,
      output);
}

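// Functional ATen entry point: allocates the output like q_projected and
// delegates to the out variant. WRAP_TO_ATEN converts between at::Tensor and
// ExecuTorch tensors; its second argument (11 here) is the zero-based index
// of the out tensor in the wrapped function's parameter list.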
at::Tensor sdpa_with_kv_cache_aten(
    const at::Tensor& q_projected,
    const at::Tensor& k_projected,
    const at::Tensor& v_projected,
    at::Tensor& key_cache,
    at::Tensor& value_cache,
    const int64_t start_pos,
    const int64_t seq_len,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<at::Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<double> scale) {
  auto output = at::empty_like(q_projected);
  WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11)
  (q_projected,
   k_projected,
   v_projected,
   key_cache,
   value_cache,
   start_pos,
   seq_len,
   attn_mask,
   dropout_p,
   is_causal,
   scale,
   output);
  return output;
}

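// Same pattern for custom_sdpa: supply a default runtime context so the
// ExecuTorch kernel can be invoked without one.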
Tensor& custom_sdpa_out_no_context(
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    const int64_t start_pos,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<double> scale,
    Tensor& output) {
  exec_aten::RuntimeContext context{};
  return torch::executor::native::custom_sdpa_out(
      context,
      q,
      k,
      v,
      start_pos,
      attn_mask,
      dropout_p,
      is_causal,
      scale,
      output);
}

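// Functional ATen wrapper for custom_sdpa; the out tensor sits at index 8 of
// custom_sdpa_out_no_context.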
at::Tensor custom_sdpa_aten(
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    const int64_t start_pos,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<at::Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<double> scale) {
  auto output = at::empty_like(q);
  WRAP_TO_ATEN(custom_sdpa_out_no_context, 8)
  (q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
  return output;
}

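// Context-free wrapper around the quantized KV-cache update kernel.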
Tensor& update_quantized_cache_out_no_context(
    const Tensor& value,
    Tensor& cache,
    const int64_t start_pos,
    Tensor& output) {
  exec_aten::RuntimeContext context{};
  return torch::executor::native::update_quantized_cache_out(
      context, value, cache, start_pos, output);
}

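// Functional ATen wrapper. The kernel mutates `cache` in place; the
// 1-element `output` allocated here appears to exist only to satisfy the
// out-variant calling convention.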
at::Tensor update_quantized_cache_aten(
    const at::Tensor& value,
    at::Tensor& cache,
    const int64_t start_pos) {
  auto output = at::empty({1});
  WRAP_TO_ATEN(update_quantized_cache_out_no_context, 3)
  (value, cache, start_pos, output);
  return output;
}

} // namespace native
} // namespace executor
} // namespace torch

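// Declare the operator schemas on the `llama` library. Once this shared
// library is loaded into PyTorch (e.g. via torch.ops.load_library in Python),
// the ops are callable as torch.ops.llama.sdpa_with_kv_cache and friends.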
TORCH_LIBRARY_FRAGMENT(llama, m) {
  m.def(
      "sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
      "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
      "float dropout_p=0.0, bool is_causal=False, float? scale=None) -> Tensor");
  m.def(
      "sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
      "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
      "float dropout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
  m.def(
      "custom_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
      "Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, "
      "float? scale=None) -> Tensor");
  m.def(
      "custom_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
      "Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, "
      "float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
  m.def(
      "update_quantized_cache(Tensor value, Tensor(a!) cache, "
      "SymInt start_pos) -> Tensor");
  m.def(
      "update_quantized_cache.out(Tensor value, Tensor(a!) cache, "
      "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
}

// TODO: Rename this file to op_custom_ops_aot.cpp
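// Bind the wrappers above as the CompositeExplicitAutograd implementations of
// the schemas declared in the fragment; the .out overloads reuse the
// context-free functions directly through WRAP_TO_ATEN.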
TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
  m.impl(
      "sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten);
  m.impl(
      "sdpa_with_kv_cache.out",
      WRAP_TO_ATEN(
          torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
  m.impl("custom_sdpa", torch::executor::native::custom_sdpa_aten);
  m.impl(
      "custom_sdpa.out",
      WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 8));
  m.impl(
      "update_quantized_cache",
      torch::executor::native::update_quantized_cache_aten);
  m.impl(
      "update_quantized_cache.out",
      WRAP_TO_ATEN(
          torch::executor::native::update_quantized_cache_out_no_context, 3));
}