/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/extension/llm/custom_ops/op_sdpa.h>
#include <executorch/extension/llm/custom_ops/op_update_quantized_cache.h>

#include <torch/library.h>

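// Bridges the ExecuTorch custom LLM kernels (SDPA with KV cache, custom
// SDPA, quantized KV-cache update) into regular ATen ops so the same
// kernels can run in eager-mode PyTorch and be traced during export. Each
// `*_no_context` wrapper supplies a default KernelRuntimeContext to the
// ExecuTorch out-variant kernel; WRAP_TO_ATEN then adapts that wrapper to
// ATen tensors, its second argument giving the zero-based index of the
// `out` tensor in the wrapped signature.
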
namespace torch {
namespace executor {

namespace native {
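// Out-variant wrapper: constructs a default KernelRuntimeContext so the
// ExecuTorch kernel can be called without an explicit runtime context.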
Tensor& sdpa_with_kv_cache_out_no_context(
    const Tensor& q_projected,
    const Tensor& k_projected,
    const Tensor& v_projected,
    Tensor& key_cache,
    Tensor& value_cache,
    const int64_t start_pos,
    const int64_t seq_len,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<double> scale,
    Tensor& output) {
  executorch::runtime::KernelRuntimeContext context{};
  return torch::executor::native::sdpa_with_kv_cache_out(
      context,
      q_projected,
      k_projected,
      v_projected,
      key_cache,
      value_cache,
      start_pos,
      seq_len,
      attn_mask,
      dropout_p,
      is_causal,
      scale,
      output);
}

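// Eager-mode entry point: allocates `output` with the shape and dtype of
// `q_projected`, then forwards to the out variant through WRAP_TO_ATEN.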
at::Tensor sdpa_with_kv_cache_aten(
    const at::Tensor& q_projected,
    const at::Tensor& k_projected,
    const at::Tensor& v_projected,
    at::Tensor& key_cache,
    at::Tensor& value_cache,
    const int64_t start_pos,
    const int64_t seq_len,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<at::Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<double> scale) {
  auto output = at::empty_like(q_projected);
  WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11)
  (q_projected,
   k_projected,
   v_projected,
   key_cache,
   value_cache,
   start_pos,
   seq_len,
   attn_mask,
   dropout_p,
   is_causal,
   scale,
   output);
  return output;
}

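// Out-variant wrapper for custom_sdpa; like the SDPA wrapper above, it only
// supplies a default runtime context.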
Tensor& custom_sdpa_out_no_context(
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    const int64_t start_pos,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<double> scale,
    Tensor& output) {
  executorch::runtime::KernelRuntimeContext context{};
  return torch::executor::native::custom_sdpa_out(
      context,
      q,
      k,
      v,
      start_pos,
      attn_mask,
      dropout_p,
      is_causal,
      scale,
      output);
}

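// Eager-mode entry point for custom_sdpa; `output` mirrors the shape and
// dtype of `q`.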
at::Tensor custom_sdpa_aten(
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    const int64_t start_pos,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<at::Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<double> scale) {
  auto output = at::empty_like(q);
  WRAP_TO_ATEN(custom_sdpa_out_no_context, 8)
  (q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
  return output;
}

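// Out-variant wrapper for update_quantized_cache with a default runtime
// context.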
Tensor& update_quantized_cache_out_no_context(
    const Tensor& value,
    Tensor& cache,
    const int64_t start_pos,
    Tensor& output) {
  executorch::runtime::KernelRuntimeContext context{};
  return torch::executor::native::update_quantized_cache_out(
      context, value, cache, start_pos, output);
}

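// Eager-mode entry point for update_quantized_cache. The kernel's real
// result is the in-place update of `cache`; the returned 1-element tensor
// is a placeholder so the op still has a functional return value.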
at::Tensor update_quantized_cache_aten(
    const at::Tensor& value,
    at::Tensor& cache,
    const int64_t start_pos) {
  auto output = at::empty({1});
  WRAP_TO_ATEN(update_quantized_cache_out_no_context, 3)
  (value, cache, start_pos, output);
  return output;
}

} // namespace native
} // namespace executor
} // namespace torch

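// Register the op schemas under the `llama` namespace. Each op comes in a
// functional variant and an `.out` variant; annotations such as
// `Tensor(a!)` mark tensors the kernels mutate in place (the KV caches and
// the `out` tensor).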
TORCH_LIBRARY_FRAGMENT(llama, m) {
  m.def(
      "sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
      "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
      "float dropout_p=0.0, bool is_causal=False, float? scale=None) -> Tensor");
  m.def(
      "sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
      "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
      "float dropout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
  m.def(
      "custom_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
      "Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, "
      "float? scale=None) -> Tensor");
  m.def(
      "custom_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
      "Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, "
      "float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
  m.def(
      "update_quantized_cache(Tensor value, Tensor(a!) cache, "
      "SymInt start_pos) -> Tensor");
  m.def(
      "update_quantized_cache.out(Tensor value, Tensor(a!) cache, "
      "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
}

// TODO: Rename this file to op_custom_ops_aot.cpp
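// Bind the schemas above to their ATen implementations; registering under
// CompositeExplicitAutograd lets one kernel serve every backend key. The
// out variants reuse the `*_no_context` wrappers directly via WRAP_TO_ATEN.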
TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
  m.impl(
      "sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten);
  m.impl(
      "sdpa_with_kv_cache.out",
      WRAP_TO_ATEN(
          torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
  m.impl("custom_sdpa", torch::executor::native::custom_sdpa_aten);
  m.impl(
      "custom_sdpa.out",
      WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 8));
  m.impl(
      "update_quantized_cache",
      torch::executor::native::update_quantized_cache_aten);
  m.impl(
      "update_quantized_cache.out",
      WRAP_TO_ATEN(
          torch::executor::native::update_quantized_cache_out_no_context, 3));
}

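// Example eager-mode usage from Python once this file is compiled into a
// shared library (the library name below is illustrative, not the actual
// build target):
//
//   import torch
//   torch.ops.load_library("libcustom_ops_aot_lib.so")
//   out = torch.ops.llama.sdpa_with_kv_cache(
//       q, k, v, key_cache, value_cache, start_pos, seq_len)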