/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/prim_ops/et_copy_index.h>
#include <executorch/kernels/prim_ops/et_view.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/kernel/operator_registry.h>

#include <cmath>

using torch::executor::function::et_copy_index;

namespace torch {
namespace executor {
namespace function {

namespace {

#define __ET_PRIM_OP_ERROR_IMPL(a, b, context)                     \
  else {                                                           \
    ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag); \
  }

// TODO Fail using runtime context
#define __NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \
  (void)context;                                           \
  EValue& a = *stack[0];                                   \
  EValue& b = *stack[1];                                   \
  EValue& out = *stack[2];                                 \
  if (a.isInt() && b.isInt()) {                            \
    out = EValue(a.toInt() operator b.toInt());            \
  } else if (a.isDouble() && b.isDouble()) {               \
    out = EValue(a.toDouble() operator b.toDouble());      \
  } else if (a.isInt() && b.isDouble()) {                  \
    out = EValue(a.toInt() operator b.toDouble());         \
  } else if (a.isDouble() && b.isInt()) {                  \
    out = EValue(a.toDouble() operator b.toInt());         \
  }

#define ALGEBRA_ET_PRIM_OP(operator, stack, context) \
  __NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \
  __ET_PRIM_OP_ERROR_IMPL(a, b, context)

#define BOOLEAN_ET_PRIM_OP(operator, stack, context) \
  __NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \
  else if (a.isBool() && b.isBool()) {               \
    out = EValue(a.toBool() operator b.toBool());    \
  }                                                  \
  __ET_PRIM_OP_ERROR_IMPL(a, b, context)
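// For illustration: ALGEBRA_ET_PRIM_OP(+, stack, context) expands to a chain
// of tag checks that reads *stack[0] and *stack[1], applies + to the unwrapped
// Int/Double payloads, and writes the result into *stack[2]; any unsupported
// tag combination falls through to the ET_CHECK_MSG in
// __ET_PRIM_OP_ERROR_IMPL. BOOLEAN_ET_PRIM_OP additionally accepts a pair of
// Bool operands.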

void floor_div_double(double a, double b, EValue& out) {
  if (b == 0) {
    out = EValue(std::signbit(a) ? -INFINITY : INFINITY);
    return;
  }
  const auto mod = std::fmod(a, b);
  auto div = (a - mod) / b;
  if ((mod != 0) && std::signbit(b) != std::signbit(mod)) {
    out = EValue(div - 1);
    return;
  }
  out = EValue(div);
}
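// Example of the Python-style semantics above: floor_div_double(5.0, -2.0, out)
// stores -3.0, since fmod(5.0, -2.0) == 1.0 differs in sign from b and the raw
// quotient (5.0 - 1.0) / -2.0 == -2.0 is therefore decremented by 1, whereas
// floor_div_double(5.0, 2.0, out) stores 2.0 with no fix-up.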

static Kernel prim_ops[] = {
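    // Each kernel below receives its operands and its result slot through a
    // flat EValue** "stack": inputs come first and the final entry is written
    // with the output.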
    // aten::sym_size.int(Tensor self, int dim) -> SymInt
    Kernel(
        "aten::sym_size.int",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& self = *stack[0];
          EValue& dim = *stack[1];
          EValue& out = *stack[2];
          exec_aten::Tensor self_tensor = self.to<exec_aten::Tensor>();
          int64_t dim_val = dim.to<int64_t>();
          int64_t size = self_tensor.size(dim_val);
          out = EValue(size);
        }),
    // aten::_local_scalar_dense(Tensor self) -> Scalar
    Kernel(
        "aten::_local_scalar_dense",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& self = *stack[0];
          EValue& out = *stack[1];
          exec_aten::Tensor self_tensor = self.to<exec_aten::Tensor>();
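          // ET_SWITCH_REAL_TYPES_AND dispatches on the tensor's runtime dtype
          // (plus Bool) and instantiates the lambda below with the matching
          // C type, so element 0 can be wrapped into a Scalar for any dtype.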
          ET_SWITCH_REAL_TYPES_AND(
              Bool,
              self_tensor.scalar_type(),
              context,
              "_local_scalar_dense",
              CTYPE,
              [&]() {
                out = EValue(Scalar(self_tensor.const_data_ptr<CTYPE>()[0]));
              });
        }),
    // aten::sym_numel(Tensor self) -> SymInt
    Kernel(
        "aten::sym_numel",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& self = *stack[0];
          EValue& out = *stack[1];
          exec_aten::Tensor self_tensor = self.to<exec_aten::Tensor>();
          int64_t numel = self_tensor.numel();
          out = EValue(numel);
        }),
    // executorch_prim::add.Scalar(Scalar, Scalar) -> Scalar
    Kernel(
        "executorch_prim::add.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          ALGEBRA_ET_PRIM_OP(+, stack, context);
        }),

    // executorch_prim::sub.Scalar(Scalar, Scalar) -> Scalar
    Kernel(
        "executorch_prim::sub.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          ALGEBRA_ET_PRIM_OP(-, stack, context);
        }),

    // executorch_prim::mul.Scalar(Scalar, Scalar) -> Scalar
    Kernel(
        "executorch_prim::mul.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          ALGEBRA_ET_PRIM_OP(*, stack, context);
        }),

    /**
     * Python's __floordiv__ operator is more complicated than just floor(a /
     * b). It aims to maintain the property: a == (a // b) * b + remainder(a, b)
     * which can otherwise fail due to rounding errors in the remainder.
     * So, instead it is calculated as: a // b = (a - remainder(a, b)) / b
     * With some additional fix-ups added to the result.
     *
     * executorch_prim::floordiv.Scalar(Scalar, Scalar) -> Scalar
     */
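    // Worked example for the integer path below: a = -7, b = 2 gives quot = -3
    // (C++ truncation toward zero); the signs differ and rem = -1 is nonzero,
    // so the result is quot - 1 = -4, matching Python's -7 // 2 == -4.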
    Kernel(
        "executorch_prim::floordiv.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& a = *stack[0];
          EValue& b = *stack[1];
          EValue& out = *stack[2];
          if (a.isInt() && b.isInt()) {
            const int64_t quot = a.toInt() / b.toInt();
            if ((a.toInt() < 0) == (b.toInt() < 0)) {
              out = EValue(quot);
              return;
            }
            const int64_t rem = a.toInt() % b.toInt();
            out = EValue(rem ? quot - 1 : quot);
            return;
          } else if (a.isDouble() && b.isDouble()) {
            floor_div_double(a.toDouble(), b.toDouble(), out);
          } else if (a.isInt() && b.isDouble()) {
            floor_div_double(static_cast<double>(a.toInt()), b.toDouble(), out);
          } else if (a.isDouble() && b.isInt()) {
            floor_div_double(a.toDouble(), static_cast<double>(b.toInt()), out);
          } else {
            // TODO Fail using runtime context
            ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
          }
        }),

    // executorch_prim::truediv.Scalar(Scalar, Scalar) -> Scalar
    Kernel(
        "executorch_prim::truediv.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          // can't use macro because of custom casting behavior
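          // (int, int) is promoted to double, so e.g. truediv(1, 2) yields 0.5
          // rather than the truncated 0 that integer division would produce.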
          (void)context;
          EValue& a = *stack[0];
          EValue& b = *stack[1];
          EValue& out = *stack[2];
          if (a.isInt() && b.isInt()) {
            out = EValue(
                static_cast<double>(a.toInt()) /
                static_cast<double>(b.toInt()));
          } else if (a.isDouble() && b.isDouble()) {
            out = EValue(a.toDouble() / b.toDouble());
          } else if (a.isInt() && b.isDouble()) {
            out = EValue(a.toInt() / b.toDouble());
          } else if (a.isDouble() && b.isInt()) {
            out = EValue(a.toDouble() / b.toInt());
          } else {
            // TODO Fail using runtime context
            ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
          }
        }),

    // executorch_prim::sym_float.Scalar(Scalar) -> Scalar
    Kernel(
        "executorch_prim::sym_float.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          // can't use macro because of custom casting behavior
          // TODO: Now that we are reliably generating conversion operators,
          // we can remove the mixed type handling for other operators
          (void)context;
          EValue& a = *stack[0];
          EValue& out = *stack[1];
          if (a.isInt()) {
            out = EValue(static_cast<double>(a.toInt()));
          } else if (a.isDouble()) {
            // TODO: This should be impossible
            out = EValue(a.toDouble());
          } else {
            // TODO Fail using runtime context
            ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
          }
        }),

    // executorch_prim::eq.Scalar(Scalar, Scalar) -> bool
    Kernel(
        "executorch_prim::eq.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          BOOLEAN_ET_PRIM_OP(==, stack, context);
        }),

    // executorch_prim::gt.Scalar(Scalar, Scalar) -> bool
    Kernel(
        "executorch_prim::gt.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          BOOLEAN_ET_PRIM_OP(>, stack, context);
        }),

    // executorch_prim::lt.Scalar(Scalar, Scalar) -> bool
    Kernel(
        "executorch_prim::lt.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          BOOLEAN_ET_PRIM_OP(<, stack, context);
        }),

    // executorch_prim::ge.Scalar(Scalar, Scalar) -> bool
    Kernel(
        "executorch_prim::ge.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          BOOLEAN_ET_PRIM_OP(>=, stack, context);
        }),

    // executorch_prim::le.Scalar(Scalar, Scalar) -> bool
    Kernel(
        "executorch_prim::le.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          BOOLEAN_ET_PRIM_OP(<=, stack, context);
        }),
    // executorch_prim::neg.Scalar(Scalar) -> Scalar
    Kernel(
        "executorch_prim::neg.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& a = *stack[0];
          EValue& out = *stack[1];
          if (a.isInt()) {
            out = EValue(-a.toInt());
          } else if (a.isDouble()) {
            out = EValue(-a.toDouble());
          } else {
            // TODO Fail using runtime context
            ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
          }
        }),

    // executorch_prim::floordiv.int(int, int) -> int
    Kernel(
        "executorch_prim::floordiv.int",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& a = *stack[0];
          EValue& b = *stack[1];
          EValue& out = *stack[2];
          out = EValue(a.toInt() / b.toInt());
        }),

    // executorch_prim::mod.int(int, int) -> int
    Kernel(
        "executorch_prim::mod.int",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& a = *stack[0];
          EValue& b = *stack[1];
          EValue& out = *stack[2];
          out = EValue(a.toInt() % b.toInt());
        }),

    // executorch_prim::mod.Scalar(Scalar, Scalar) -> Scalar
    Kernel(
        "executorch_prim::mod.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& a = *stack[0];
          EValue& b = *stack[1];
          EValue& out = *stack[2];
          if (a.isInt() && b.isInt()) {
            out = EValue(a.toInt() % b.toInt());
          } else {
            ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
          }
        }),

    // ceil.Scalar(Scalar a) -> Scalar
    Kernel(
        "executorch_prim::ceil.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& a = *stack[0];
          EValue& out = *stack[1];
          if (a.isDouble()) {
            out = EValue(static_cast<int64_t>(ceil(a.toDouble())));
          } else {
            ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag);
          }
        }),

    // round.Scalar(Scalar a) -> Scalar
    Kernel(
        "executorch_prim::round.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& a = *stack[0];
          EValue& out = *stack[1];
          if (a.isDouble()) {
            // Round half to even to match Python round(). Need an explicit
            // implementation as not all platforms support fenv rounding modes.
            // See
            // https://codeyarns.com/tech/2018-08-17-how-to-round-half-to-even.html
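            // For example: 0.5 rounds to 0, 1.5 and 2.5 both round to 2, and
            // -0.5 rounds to 0, matching Python's round().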
            const auto val = a.toDouble();
            const auto r = round(val);
            const auto d = r - val;
            auto res = 0.0;

            if (std::abs(d) != 0.5) {
              res = r;
            } else if (fmod(r, 2.0) == 0.0) {
              res = r;
            } else {
              res = val - d;
            }

            out = EValue(static_cast<int64_t>(res));
          } else {
            ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag);
          }
        }),

    // trunc.Scalar(Scalar a) -> Scalar
    Kernel(
        "executorch_prim::trunc.Scalar",
        [](KernelRuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& a = *stack[0];
          EValue& out = *stack[1];
          if (a.isDouble()) {
            out = EValue(static_cast<int64_t>(trunc(a.toDouble())));
          } else {
            ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
          }
        }),

    // executorch_prim::et_copy_index.tensor(tensor, tensor) -> tensor
    Kernel(
        "executorch_prim::et_copy_index.tensor",
        [](KernelRuntimeContext& context, EValue** stack) {
          et_copy_index(context, stack);
        }),
    // executorch_prim::et_view.default(Tensor, int[]) -> Tensor
    Kernel(
        "executorch_prim::et_view.default",
        [](KernelRuntimeContext& context, EValue** stack) {
          et_view(context, stack);
        }),

};

executorch::runtime::Span<const executorch::runtime::Kernel> kernel_span(
    prim_ops,
    prim_ops + sizeof(prim_ops) / sizeof(Kernel));

// Return value not used. Keep the static variable assignment to register
// operators at static initialization time.
auto success_with_kernel_reg =
    executorch::runtime::register_kernels(kernel_span);
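// At this point every entry in prim_ops can be resolved by name through the
// ExecuTorch operator registry, which is how the runtime binds operators
// referenced by a program, such as "aten::sym_size.int", to the lambdas above.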

} // namespace
} // namespace function
} // namespace executor
} // namespace torch