/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/prim_ops/et_copy_index.h>
#include <executorch/kernels/prim_ops/et_view.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/kernel/operator_registry.h>

#include <cmath>

using torch::executor::function::et_copy_index;

namespace torch {
namespace executor {
namespace function {

namespace {

#define __ET_PRIM_OP_ERROR_IMPL(a, b, context)                     \
  else {                                                           \
    ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag); \
  }

// TODO Fail using runtime context
#define __NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \
  (void)context;                                           \
  EValue& a = *stack[0];                                   \
  EValue& b = *stack[1];                                   \
  EValue& out = *stack[2];                                 \
  if (a.isInt() && b.isInt()) {                            \
    out = EValue(a.toInt() operator b.toInt());            \
  } else if (a.isDouble() && b.isDouble()) {               \
    out = EValue(a.toDouble() operator b.toDouble());      \
  } else if (a.isInt() && b.isDouble()) {                  \
    out = EValue(a.toInt() operator b.toDouble());         \
  } else if (a.isDouble() && b.isInt()) {                  \
    out = EValue(a.toDouble() operator b.toInt());         \
  }

#define ALGEBRA_ET_PRIM_OP(operator, stack, context) \
  __NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \
  __ET_PRIM_OP_ERROR_IMPL(a, b, context)

#define BOOLEAN_ET_PRIM_OP(operator, stack, context) \
  __NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \
  else if (a.isBool() && b.isBool()) {               \
    out = EValue(a.toBool() operator b.toBool());    \
  }                                                  \
  __ET_PRIM_OP_ERROR_IMPL(a, b, context)
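// The macros above expand to an if/else-if chain over the operand tags. For
// example, ALGEBRA_ET_PRIM_OP(+, stack, context) covers the (int, int),
// (double, double), (int, double), and (double, int) combinations and falls
// through to ET_CHECK_MSG for any other tag pair; BOOLEAN_ET_PRIM_OP adds a
// (bool, bool) branch.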

// Floor division on doubles, following Python's // rounding: the quotient is
// adjusted downward by one when the remainder and the divisor have opposite
// signs. Division by zero yields +/-infinity with the sign taken from the
// numerator.
void floor_div_double(double a, double b, EValue& out) {
  if (b == 0) {
    out = EValue(std::signbit(a) ? -INFINITY : INFINITY);
    return;
  }
  const auto mod = std::fmod(a, b);
  auto div = (a - mod) / b;
  if ((mod != 0) && std::signbit(b) != std::signbit(mod)) {
    out = EValue(div - 1);
    return;
  }
  out = EValue(div);
}

static Kernel prim_ops[] = {
    // aten::sym_size.int(Tensor self, int dim) -> SymInt
    Kernel(
        "aten::sym_size.int",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& self = *stack[0];
          EValue& dim = *stack[1];
          EValue& out = *stack[2];
          exec_aten::Tensor self_tensor = self.to<exec_aten::Tensor>();
          int64_t dim_val = dim.to<int64_t>();
          int64_t size = self_tensor.size(dim_val);
          out = EValue(size);
        }),
    // aten::_local_scalar_dense(Tensor self) -> Scalar
    Kernel(
        "aten::_local_scalar_dense",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& self = *stack[0];
          EValue& out = *stack[1];
          exec_aten::Tensor self_tensor = self.to<exec_aten::Tensor>();
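          // Switch over the tensor's dtype (real types plus Bool) and wrap
          // its first element in a Scalar.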
          ET_SWITCH_REAL_TYPES_AND(
              Bool,
              self_tensor.scalar_type(),
              context,
              "_local_scalar_dense",
              CTYPE,
              [&]() {
                out = EValue(Scalar(self_tensor.const_data_ptr<CTYPE>()[0]));
              });
        }),
    // aten::sym_numel(Tensor self) -> SymInt
    Kernel(
        "aten::sym_numel",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& self = *stack[0];
          EValue& out = *stack[1];
          exec_aten::Tensor self_tensor = self.to<exec_aten::Tensor>();
          int64_t numel = self_tensor.numel();
          out = EValue(numel);
        }),
    // executorch_prim::add.Scalar(Scalar, Scalar) -> Scalar
    Kernel(
        "executorch_prim::add.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          ALGEBRA_ET_PRIM_OP(+, stack, context);
        }),

    // executorch_prim::sub.Scalar(Scalar, Scalar) -> Scalar
    Kernel(
        "executorch_prim::sub.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          ALGEBRA_ET_PRIM_OP(-, stack, context);
        }),

    // executorch_prim::mul.Scalar(Scalar, Scalar) -> Scalar
    Kernel(
        "executorch_prim::mul.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          ALGEBRA_ET_PRIM_OP(*, stack, context);
        }),

    /**
     * Python's __floordiv__ operator is more complicated than just floor(a /
     * b). It aims to maintain the property: a == (a // b) * b + remainder(a, b)
     * which can otherwise fail due to rounding errors in the remainder.
     * So, instead it is calculated as: a // b = (a - remainder(a, b)) / b
     * With some additional fix-ups added to the result.
     *
     * executorch_prim::floordiv.Scalar(Scalar, Scalar) -> Scalar
     */
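    // Illustrative example: with a = -7 and b = 2, remainder(-7, 2) == 1, so
    // a // b == (-7 - 1) / 2 == -4, whereas plain truncation of -7 / 2 would
    // give -3.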
    Kernel(
        "executorch_prim::floordiv.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& a = *stack[0];
          EValue& b = *stack[1];
          EValue& out = *stack[2];
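          // Integer path: start from C++ truncating division and subtract 1
          // when the operands have opposite signs and the remainder is
          // non-zero, which matches Python's floor-division result.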
          if (a.isInt() && b.isInt()) {
            const int64_t quot = a.toInt() / b.toInt();
            if ((a.toInt() < 0) == (b.toInt() < 0)) {
              out = EValue(quot);
              return;
            }
            const int64_t rem = a.toInt() % b.toInt();
            out = EValue(rem ? quot - 1 : quot);
            return;
          } else if (a.isDouble() && b.isDouble()) {
            floor_div_double(a.toDouble(), b.toDouble(), out);
          } else if (a.isInt() && b.isDouble()) {
            floor_div_double(static_cast<double>(a.toInt()), b.toDouble(), out);
          } else if (a.isDouble() && b.isInt()) {
            floor_div_double(a.toDouble(), static_cast<double>(b.toInt()), out);
          } else {
            // TODO Fail using runtime context
            ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
          }
        }),

    // executorch_prim::truediv.Scalar(Scalar, Scalar) -> Scalar
    Kernel(
        "executorch_prim::truediv.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          // can't use macro because of custom casting behavior
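          // (int, int) operands are both promoted to double before dividing,
          // so e.g. 7 / 2 evaluates to 3.5 here rather than 3.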
          (void)context;
          EValue& a = *stack[0];
          EValue& b = *stack[1];
          EValue& out = *stack[2];
          if (a.isInt() && b.isInt()) {
            out = EValue(
                static_cast<double>(a.toInt()) /
                static_cast<double>(b.toInt()));
          } else if (a.isDouble() && b.isDouble()) {
            out = EValue(a.toDouble() / b.toDouble());
          } else if (a.isInt() && b.isDouble()) {
            out = EValue(a.toInt() / b.toDouble());
          } else if (a.isDouble() && b.isInt()) {
            out = EValue(a.toDouble() / b.toInt());
          } else {
            // TODO Fail using runtime context
            ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
          }
        }),

    // executorch_prim::sym_float.Scalar(Scalar) -> Scalar
    Kernel(
        "executorch_prim::sym_float.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          // can't use macro because of custom casting behavior
          // TODO: Now that we are reliably generating conversion operators,
          // we can remove the mixed type handling for other operators
          (void)context;
          EValue& a = *stack[0];
          EValue& out = *stack[1];
          if (a.isInt()) {
            out = EValue(static_cast<double>(a.toInt()));
          } else if (a.isDouble()) {
            // TODO: This should be impossible
            out = EValue(a.toDouble());
          } else {
            // TODO Fail using runtime context
            ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
          }
        }),

    // executorch_prim::eq.Scalar(Scalar, Scalar) -> bool
    Kernel(
        "executorch_prim::eq.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          BOOLEAN_ET_PRIM_OP(==, stack, context);
        }),

    // executorch_prim::gt.Scalar(Scalar, Scalar) -> bool
    Kernel(
        "executorch_prim::gt.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          BOOLEAN_ET_PRIM_OP(>, stack, context);
        }),

    // executorch_prim::lt.Scalar(Scalar, Scalar) -> bool
    Kernel(
        "executorch_prim::lt.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          BOOLEAN_ET_PRIM_OP(<, stack, context);
        }),

    // executorch_prim::ge.Scalar(Scalar, Scalar) -> bool
    Kernel(
        "executorch_prim::ge.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          BOOLEAN_ET_PRIM_OP(>=, stack, context);
        }),

    // executorch_prim::le.Scalar(Scalar, Scalar) -> bool
    Kernel(
        "executorch_prim::le.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          BOOLEAN_ET_PRIM_OP(<=, stack, context);
        }),
    // executorch_prim::neg.Scalar(Scalar) -> Scalar
    Kernel(
        "executorch_prim::neg.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& a = *stack[0];
          EValue& out = *stack[1];
          if (a.isInt()) {
            out = EValue(-a.toInt());
          } else if (a.isDouble()) {
            out = EValue(-a.toDouble());
          } else {
            // TODO Fail using runtime context
            ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
          }
        }),

    // executorch_prim::floordiv.int(int, int) -> int
    Kernel(
        "executorch_prim::floordiv.int",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& a = *stack[0];
          EValue& b = *stack[1];
          EValue& out = *stack[2];
          out = EValue(a.toInt() / b.toInt());
        }),

    // executorch_prim::mod.int(int, int) -> int
    Kernel(
        "executorch_prim::mod.int",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& a = *stack[0];
          EValue& b = *stack[1];
          EValue& out = *stack[2];
          out = EValue(a.toInt() % b.toInt());
        }),

    // executorch_prim::mod.Scalar(Scalar, Scalar) -> Scalar
    Kernel(
        "executorch_prim::mod.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& a = *stack[0];
          EValue& b = *stack[1];
          EValue& out = *stack[2];
          if (a.isInt() && b.isInt()) {
            out = EValue(a.toInt() % b.toInt());
          } else {
            ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
          }
        }),

    // ceil.Scalar(Scalar a) -> Scalar
    Kernel(
        "executorch_prim::ceil.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& a = *stack[0];
          EValue& out = *stack[1];
          if (a.isDouble()) {
            out = EValue(static_cast<int64_t>(ceil(a.toDouble())));
          } else {
            ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag);
          }
        }),

    // round.Scalar(Scalar a) -> Scalar
    Kernel(
        "executorch_prim::round.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& a = *stack[0];
          EValue& out = *stack[1];
          if (a.isDouble()) {
            // Round half to even to match Python round(). Need an explicit
            // implementation as not all platforms support fenv rounding modes.
            // See
            // https://codeyarns.com/tech/2018-08-17-how-to-round-half-to-even.html
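            // For example: 0.5 rounds to 0, 1.5 rounds to 2, and 2.5 rounds
            // to 2, matching Python's round().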
            const auto val = a.toDouble();
            const auto r = round(val);
            const auto d = r - val;
            auto res = 0.0;

            if (std::abs(d) != 0.5) {
              res = r;
            } else if (fmod(r, 2.0) == 0.0) {
              res = r;
            } else {
              res = val - d;
            }

            out = EValue(static_cast<int64_t>(res));
          } else {
            ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag);
          }
        }),

    // trunc.Scalar(Scalar a) -> Scalar
    Kernel(
        "executorch_prim::trunc.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& a = *stack[0];
          EValue& out = *stack[1];
          if (a.isDouble()) {
            out = EValue(static_cast<int64_t>(trunc(a.toDouble())));
          } else {
            ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
          }
        }),

    // executorch_prim::et_copy_index.tensor(tensor, tensor) -> tensor
    Kernel(
        "executorch_prim::et_copy_index.tensor",
        [](KernelRuntimeContext& context, EValue** stack) {
          et_copy_index(context, stack);
        }),
    // executorch_prim::et_view.default(Tensor, int[]) -> Tensor
    Kernel(
        "executorch_prim::et_view.default",
        [](KernelRuntimeContext& context, EValue** stack) {
          et_view(context, stack);
        }),

};

executorch::runtime::Span<const executorch::runtime::Kernel> kernel_span(
    prim_ops,
    prim_ops + sizeof(prim_ops) / sizeof(Kernel));

// Return value not used. Keep the static variable assignment so that the
// operators are registered at static initialization time.
auto success_with_kernel_reg =
    executorch::runtime::register_kernels(kernel_span);

} // namespace
} // namespace function
} // namespace executor
} // namespace torch