#pragma once

#include <ATen/Context.h>
#include <c10/core/DeviceType.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/mobile/register_ops_common_utils.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/logging.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/print_handler.h>
#include <torch/csrc/jit/runtime/profiling_record.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>
#include <torch/csrc/jit/serialization/pickle.h>

#include <ATen/ExpandUtils.h>
#include <ATen/Parallel.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/Dict.h>
#include <ATen/core/Generator.h>
#include <ATen/core/ivalue.h>
#include <c10/core/Device.h>
#include <c10/core/thread_pool.h>
#include <c10/util/SmallVector.h>
#include <c10/util/irange.h>
namespace torch::jit {
constexpr inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
  return c10::AliasAnalysisKind::FROM_SCHEMA;
}

constexpr inline c10::AliasAnalysisKind aliasAnalysisConservative() {
  return c10::AliasAnalysisKind::CONSERVATIVE;
}

constexpr inline c10::AliasAnalysisKind aliasAnalysisSpecialCase() {
  return c10::AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
}

template <class T>
c10::List<T> make_result_list(const TypePtr& elemType) {
  return c10::List<T>();
}

template <>
c10::impl::GenericList make_result_list<IValue>(const TypePtr& elemType);

// As described in https://docs.python.org/3/library/functions.html#round
// When a number is exactly halfway between two integers, the Python builtin
// round function rounds to the nearest even number. We use round(x/2)*2 to
// handle the special halfway case. For positive 'x', round(x/2)*2 =
// round((x_e + x_r)/2)*2 = x_e + round(x_r/2)*2, where x_e is an even integer
// and x_r is either 0.5 or 1.5; round(x_r/2)*2 yields 0 or 2, so the final
// result is always an even number. By symmetry, the same applies to negative
// values.
inline double round_to_even(double a) {
  return a - std::floor(a) == 0.5 ? (std::round(a * 0.5) * 2.0) : std::round(a);
}
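
// A few worked values for round_to_even (a quick sanity sketch, not part of
// the API):
//   round_to_even(0.5)  == 0.0   // halfway: rounds to the even 0
//   round_to_even(1.5)  == 2.0   // halfway: rounds to the even 2
//   round_to_even(2.5)  == 2.0   // halfway: rounds to the even 2, not 3
//   round_to_even(2.4)  == 2.0   // not halfway: ordinary rounding
//   round_to_even(-0.5) == 0.0   // negative halfway values behave symmetrically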

// Using the rules from python_arg_parser FunctionParameter::check:
// the tensor cannot have grad set, the tensor must be 0-dim,
// and if the dest is an int, the source must be an integral type.
void checkImplicitTensorToNum(const at::Tensor& t, bool toInt);

static C10_UNUSED int64_t floordiv(int64_t a, int64_t b) {
  if (b == 0) {
    throw std::runtime_error("division by 0");
  }
  if ((a > 0) == (b > 0)) {
    // simple case, both have same sign
    return a / b;
  } else {
    // in Python, division rounds down; it doesn't truncate like in C++
    auto r = lldiv(a, b);
    return (r.rem) ? r.quot - 1 : r.quot;
  }
}
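// Worked examples of floordiv's Python-style floor division (sanity sketch):
//   floordiv(7, 2)  ==  3   // same as C++ 7 / 2
//   floordiv(-7, 2) == -4   // C++ truncation gives -3; Python floors to -4
//   floordiv(7, -2) == -4   // mixed signs round toward negative infinity
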
TORCH_API void checkDoubleInRange(double a);
static C10_UNUSED int64_t floor(double a) {
  checkDoubleInRange(a);
  return std::floor(a);
}
static C10_UNUSED int64_t ceil(double a) {
  checkDoubleInRange(a);
  return std::ceil(a);
}

static C10_UNUSED int64_t gcd(int64_t a, int64_t b) {
  while (b != 0) {
    int64_t r = a % b;
    a = b;
    b = r;
  }
  // in python gcd returns non-negative values
  return std::abs(a);
}
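
// Example: gcd(-12, 18) == 6; like Python's math.gcd, the result is
// non-negative regardless of the signs of the inputs.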

int64_t partProduct(int n, int m);

void loop(int n, int64_t& p, int64_t& r);

int nminussumofbits(int v);

int64_t factorial(int n);
static const double degToRad = std::acos(-1.0) / 180.0;
static const double radToDeg = 180.0 / std::acos(-1.0);
double degrees(double x);
double radians(double x);

// Convert a Python index (which may be negative) into an index usable for a
// C++ container

// Equivalent to list.at(idx)
template <typename T>
decltype(auto) getItem(const c10::List<T>& list, int64_t idx) {
  const int64_t list_size = list.size();
  const int64_t normalized_idx = normalizeIndex(idx, list_size);
  if (normalized_idx < 0 || normalized_idx >= list_size) {
    throw std::out_of_range("list index out of range");
  }
  return list.get(normalized_idx);
}
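
// Example of the Python-style indexing above: for a list of size 3,
//   getItem(list, -1)  // same as getItem(list, 2): returns the last element
//   getItem(list, 3)   // throws std::out_of_range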

template <typename T>
void setItem(const c10::List<T>& list, int64_t idx, T&& value) {
  const int64_t list_size = list.size();
  const int64_t normalized_idx = normalizeIndex(idx, list_size);
  if (normalized_idx < 0 || normalized_idx >= list_size) {
    throw std::out_of_range("list index out of range");
  }
  list.set(normalized_idx, std::forward<T>(value));
}

void listAppend(Stack& stack);

void listReverse(Stack& stack);

template <typename T>
void minList(Stack& stack) {
  c10::List<T> a = pop(stack).to<c10::List<T>>();
  c10::List<T> b = pop(stack).to<c10::List<T>>();

  size_t min_size = std::min(a.size(), b.size());
  for (const auto i : c10::irange(min_size)) {
    if (a[i] == b[i]) {
      continue;
    }

    push(stack, a[i] < b[i] ? a : b);
    return;
  }

  push(stack, b.size() < a.size() ? b : a);
}
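
// minList mirrors Python's lexicographic list ordering, e.g.:
//   min([1, 2, 3], [1, 3]) == [1, 2, 3]   // first differing element: 2 < 3
//   min([1, 2], [1, 2, 0]) == [1, 2]      // equal prefix: the shorter list wins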

template <typename T>
void maxList(Stack& stack) {
  c10::List<T> a = pop(stack).to<c10::List<T>>();
  c10::List<T> b = pop(stack).to<c10::List<T>>();

  size_t min_size = std::min(a.size(), b.size());
  for (const auto i : c10::irange(min_size)) {
    if (a[i] == b[i]) {
      continue;
    }

    push(stack, a[i] > b[i] ? a : b);
    return;
  }

  push(stack, b.size() > a.size() ? b : a);
}

void listPopImpl(Stack& stack, const char* empty_message);

void listPop(Stack& stack);

void listClear(Stack& stack);

void listDelete(Stack& stack);

void listInsert(Stack& stack);

template <typename T>
void listRemove(Stack& stack) {
  T elem = pop(stack).to<T>();
  c10::List<T> list = pop(stack).to<c10::List<T>>();

  auto pos = std::find(list.begin(), list.end(), elem);

  if (pos != list.end()) {
    list.erase(pos);
  } else {
    AT_ERROR("list.remove(x): x not in list");
  }
}

template <typename T>
void listMin(Stack& stack) {
  c10::List<T> list = pop(stack).to<c10::List<T>>();
  size_t list_size = list.size();
  if (list_size == 0) {
    throw std::runtime_error("min() arg is an empty sequence");
  }

  T min_elem = list[0];
  for (const auto i : c10::irange(1, list_size)) {
    T elem = list[i];
    min_elem = elem < min_elem ? elem : min_elem;
  }

  stack.push_back(min_elem);
}

template <typename T>
void listMax(Stack& stack) {
  c10::List<T> list = pop(stack).to<c10::List<T>>();
  size_t list_size = list.size();
  if (list_size == 0) {
    throw std::runtime_error("max() arg is an empty sequence");
  }

  T max_elem = list[0];
  for (const auto i : c10::irange(1, list_size)) {
    T elem = list[i];
    max_elem = elem > max_elem ? elem : max_elem;
  }

  stack.push_back(max_elem);
}

template <>
void listRemove<at::Tensor>(Stack& stack);

template <typename T>
void listIndex(Stack& stack) {
  T elem = pop(stack).to<T>();
  c10::List<T> list = pop(stack).to<c10::List<T>>();

  auto pos = std::find(list.begin(), list.end(), elem);

  if (pos != list.end()) {
    push(stack, static_cast<int64_t>(std::distance(list.begin(), pos)));
  } else {
    AT_ERROR("'", elem, "' is not in list");
  }
}

template <>
void listIndex<at::Tensor>(Stack& stack);

template <typename T>
void listCount(Stack& stack) {
  T elem = pop(stack).to<T>();
  c10::List<T> list = pop(stack).to<c10::List<T>>();

  const int64_t count = std::count(list.begin(), list.end(), elem);
  push(stack, count);
}

template <>
void listCount<at::Tensor>(Stack& stack);

void listExtend(Stack& stack);

void listCopy(Stack& stack);

void listSelect(Stack& stack);

void listLen(Stack& stack);

template <typename T>
void listEq(Stack& stack) {
  c10::List<T> b = pop(stack).to<c10::List<T>>();
  c10::List<T> a = pop(stack).to<c10::List<T>>();
  push(stack, a == b);
}

template <typename T>
void listNe(Stack& stack) {
  c10::List<T> b = pop(stack).to<c10::List<T>>();
  c10::List<T> a = pop(stack).to<c10::List<T>>();
  push(stack, a != b);
}

inline bool tensor_list_equal(
    const c10::List<at::Tensor>& a,
    const c10::List<at::Tensor>& b) {
  if (a.size() != b.size()) {
    return false;
  }

  for (const auto i : c10::irange(a.size())) {
    const at::Tensor& a_element = a[i];
    const at::Tensor& b_element = b[i];
    // This preserves Python's semantics, which uses eq() to compare two
    // elements, then passes the result to bool().
    // see: https://docs.python.org/3.4/reference/datamodel.html#object.__ge__
    const auto cmp_result = a_element.eq(b_element);
    if (!at::native::is_nonzero(cmp_result)) {
      return false;
    }
  }

  return true;
}

// Specialization for at::Tensor, since it doesn't define operator==
template <>
void listEq<at::Tensor>(Stack& stack);

// Specialization for at::Tensor, since it doesn't define operator==
template <>
void listNe<at::Tensor>(Stack& stack);

void listList(Stack& stack);

template <typename T>
void listContains(Stack& stack) {
  auto key = pop(stack).to<T>();
  auto list = pop(stack).to<c10::List<T>>();
  // NOLINTNEXTLINE(performance-implicit-conversion-in-loop)
  for (const T& item : list) {
    if (item == key) {
      push(stack, true);
      return;
    }
  }
  push(stack, false);
}

void listAdd(Stack& stack);

void listInplaceAdd(Stack& stack);

void listMulIntLeftInPlace(Stack& stack);

void listMulIntLeft(Stack& stack);

void listMulIntRight(Stack& stack);

void listSlice(Stack& stack);

template <typename T>
void listSort(Stack& stack) {
  bool reverse = pop(stack).toBool();
  c10::List<T> list = pop(stack).to<c10::List<T>>();
  std::sort(list.begin(), list.end(), [reverse](const T& a, const T& b) {
    // FBCode errors without this check - "strict weak ordering"
    // TODO: remove when possible, since it just slows down
    // sorting and doesn't do anything useful
    if (a == b) {
      return false;
    }
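    // `(a < b) != reverse` inverts the comparison when reverse is true:
    // XOR-ing the boolean result with `reverse` flips ascending to descending.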
    return (a < b) != reverse;
  });
}

// Specialization for at::Tensor
template <>
void listSort<at::Tensor>(Stack& stack);

template <typename T>
void listCopyAndSort(Stack& stack) {
  c10::List<T> list = pop(stack).to<c10::List<T>>();
  auto list_copied = list.copy();
  std::sort(list_copied.begin(), list_copied.end(), [](const T& a, const T& b) {
    // "strict weak ordering" issue - see other sort
    if (a == b) {
      return false;
    }
    return a < b;
  });
  push(stack, list_copied);
}

// Specialization for at::Tensor
template <>
void listCopyAndSort<at::Tensor>(Stack& stack);

void listSetItem(Stack& stack);

struct OperatorGeneratorArgs {
  const char* schema_str;
  bool isOperationCreator;
  union {
    void (*operation)(Stack&);
    OperationCreator operationCreator;
  };
  AliasAnalysisKind aliasAnalysis;

  explicit constexpr OperatorGeneratorArgs(
      torch::detail::SelectiveStr<true> schema_str,
      void (*op)(Stack&),
      AliasAnalysisKind aa)
      : schema_str(schema_str),
        isOperationCreator(false),
        operation(op),
        aliasAnalysis(aa) {}

  explicit constexpr OperatorGeneratorArgs(
      torch::detail::SelectiveStr<true> schema_str,
      OperationCreator opCreator,
      AliasAnalysisKind aa)
      : schema_str(schema_str),
        isOperationCreator(true),
        operationCreator(opCreator),
        aliasAnalysis(aa) {}

  template <typename... Args>
  explicit constexpr OperatorGeneratorArgs(
      torch::detail::SelectiveStr<false>,
      Args...)
      : schema_str(nullptr),
        isOperationCreator(false),
        operation(nullptr),
        aliasAnalysis(AliasAnalysisKind::INTERNAL_SPECIAL_CASE) {}
};
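
// Minimal usage sketch (hypothetical schema, for illustration only). A
// captureless lambda converts to the `void (*)(Stack&)` member, and the
// SelectiveStr<false> overload produces an inert entry when the schema is
// not selected for the build:
//
//   constexpr OperatorGeneratorArgs example_args(
//       TORCH_SELECTIVE_SCHEMA("aten::example_identity(int a) -> int"),
//       [](Stack& stack) {
//         int64_t a;
//         pop(stack, a);
//         push(stack, a);
//       },
//       aliasAnalysisFromSchema());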

#define DEFINE_GENERIC_BINARY_OP(                                             \
    aten_op, op, int_float_result, complex_result)                            \
  OperatorGeneratorArgs(                                                      \
      TORCH_SELECTIVE_SCHEMA(#aten_op                                         \
                             ".int_int(int a, int b) -> " #int_float_result), \
      [](Stack& stack) {                                                      \
        int64_t a, b;                                                         \
        pop(stack, a, b);                                                     \
        push(stack, op);                                                      \
      },                                                                      \
      aliasAnalysisFromSchema()),                                             \
      OperatorGeneratorArgs(                                                  \
          TORCH_SELECTIVE_SCHEMA(                                             \
              #aten_op                                                        \
              ".float_float(float a, float b) -> " #int_float_result),        \
          [](Stack& stack) {                                                  \
            double a, b;                                                      \
            pop(stack, a, b);                                                 \
            push(stack, op);                                                  \
          },                                                                  \
          aliasAnalysisFromSchema()),                                         \
      OperatorGeneratorArgs(                                                  \
          TORCH_SELECTIVE_SCHEMA(                                             \
              #aten_op                                                        \
              ".complex_complex(complex a, complex b) -> " #complex_result),  \
          [](Stack& stack) {                                                  \
            c10::complex<double> a, b;                                        \
            pop(stack, a, b);                                                 \
            push(stack, op);                                                  \
          },                                                                  \
          aliasAnalysisFromSchema())
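
// Usage sketch (hypothetical operator, for illustration): the macro splices
// `op` into each lambda with `a` and `b` already popped, so
//   DEFINE_GENERIC_BINARY_OP(aten::example_log, std::log(a) / std::log(b), float, complex)
// would register .int_int, .float_float, and .complex_complex overloads that
// each push the evaluated expression.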

// define implementations for primitive number ops
#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
  OperatorGeneratorArgs(                                                       \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a, int b) -> " #int_result),   \
      [](Stack& stack) {                                                       \
        int64_t a, b;                                                          \
        pop(stack, a, b);                                                      \
        push(stack, int_op);                                                   \
      },                                                                       \
      aliasAnalysisFromSchema()),                                              \
      OperatorGeneratorArgs(                                                   \
          TORCH_SELECTIVE_SCHEMA(                                              \
              #aten_op ".float(float a, float b) -> " #float_result),          \
          [](Stack& stack) {                                                   \
            double a, b;                                                       \
            pop(stack, a, b);                                                  \
            push(stack, float_op);                                             \
          },                                                                   \
          aliasAnalysisFromSchema())

#define DEFINE_INT_FLOAT_OP(aten_op, op, result)                            \
  OperatorGeneratorArgs(                                                    \
      TORCH_SELECTIVE_SCHEMA(#aten_op                                       \
                             ".int_float(int a, float b) -> " #result),     \
      [](Stack& stack) {                                                    \
        int64_t a;                                                          \
        double b;                                                           \
        pop(stack, a, b);                                                   \
        push(stack, op);                                                    \
      },                                                                    \
      aliasAnalysisFromSchema()),                                           \
      OperatorGeneratorArgs(                                                \
          TORCH_SELECTIVE_SCHEMA(#aten_op                                   \
                                 ".float_int(float a, int b) -> " #result), \
          [](Stack& stack) {                                                \
            double a;                                                       \
            int64_t b;                                                      \
            pop(stack, a, b);                                               \
            push(stack, op);                                                \
          },                                                                \
          aliasAnalysisFromSchema())

#define DEFINE_INT_OP(aten_op, op)                                  \
  OperatorGeneratorArgs(                                            \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a, int b) -> int"), \
      [](Stack& stack) {                                            \
        int64_t a, b;                                               \
        pop(stack, a, b);                                           \
        push(stack, op); /* NOLINT(hicpp-signed-bitwise) */         \
      },                                                            \
      aliasAnalysisFromSchema())
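
// Usage sketch (hypothetical name): DEFINE_INT_OP(aten::example_and, a & b)
// registers a single ".int(int a, int b) -> int" overload that pops two ints
// and pushes `a & b`.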

#define DEFINE_STR_CMP_OP(aten_op, op)                               \
  OperatorGeneratorArgs(                                             \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".str(str a, str b) -> bool"), \
      [](Stack& stack) {                                             \
        auto b = pop(stack).toStringRef();                           \
        auto a = pop(stack).toStringRef();                           \
        push(stack, op);                                             \
      },                                                             \
      aliasAnalysisFromSchema())

// define a primitive op over Scalar operands.
// it's necessary to register this overload following
// int/float variations to avoid trapping Scalar args
// in unintended implicit conversions
#define DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION_GENERIC(          \
    aten_op, int_op, float_op, result, string_val)                \
  OperatorGeneratorArgs(                                          \
      TORCH_SELECTIVE_SCHEMA(#aten_op string_val                  \
                             "(Scalar a, Scalar b) -> " #result), \
      [](Stack& stack) {                                          \
        IValue x, y;                                              \
        pop(stack, x, y);                                         \
        if (x.isDouble()) {                                       \
          if (y.isDouble()) {                                     \
            double a = x.toDouble();                              \
            double b = y.toDouble();                              \
            push(stack, float_op);                                \
          } else {                                                \
            double a = x.toDouble();                              \
            int64_t b = y.toInt();                                \
            push(stack, float_op);                                \
          }                                                       \
        } else {                                                  \
          if (y.isDouble()) {                                     \
            int64_t a = x.toInt();                                \
            double b = y.toDouble();                              \
            push(stack, float_op);                                \
          } else {                                                \
            int64_t a = x.toInt();                                \
            int64_t b = y.toInt();                                \
            push(stack, int_op);                                  \
          }                                                       \
        }                                                         \
      },                                                          \
      aliasAnalysisFromSchema())

#define DEFINE_SCALAR_BINARY_OP(aten_op, int_op, float_op, result) \
  DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION_GENERIC(                 \
      aten_op, int_op, float_op, result, "")

#define DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION(   \
    aten_op, int_op, float_op, result)             \
  DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION_GENERIC( \
      aten_op, int_op, float_op, result, ".Scalar_Scalar")

#define DEFINE_BINARY_OP(aten_op, op)             \
  DEFINE_GENERIC_OP(aten_op, op, op, int, float), \
      DEFINE_INT_FLOAT_OP(aten_op, op, float),    \
      DEFINE_SCALAR_BINARY_OP(aten_op, op, op, Scalar)
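
// Composition sketch: DEFINE_BINARY_OP(aten_op, a + b), say, fans out to the
// .int, .float, .int_float, .float_int, and Scalar/Scalar overloads defined
// above, all sharing the same `a + b` expression.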

#define DEFINE_BINARY_FLOAT_OP(aten_op, op)         \
  DEFINE_GENERIC_OP(aten_op, op, op, float, float), \
      DEFINE_INT_FLOAT_OP(aten_op, op, float),      \
      DEFINE_SCALAR_BINARY_OP(aten_op, op, op, float)

#define DEFINE_COMPARISON_OP(aten_op, op)             \
  DEFINE_GENERIC_OP(aten_op, op, op, bool, bool),     \
      DEFINE_INT_FLOAT_OP(aten_op, op, bool),         \
      DEFINE_SCALAR_BINARY_OP(aten_op, op, op, bool), \
      DEFINE_STR_CMP_OP(aten_op, op)

#define DEFINE_UNARY_INT_OP(aten_op, op, result)                  \
  OperatorGeneratorArgs(                                          \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a) -> " #result), \
      [](Stack& stack) {                                          \
        int64_t a;                                                \
        pop(stack, a);                                            \
        push(stack, op);                                          \
      },                                                          \
      aliasAnalysisFromSchema())

#define DEFINE_UNARY_FLOAT_OP(aten_op, op, result)                    \
  OperatorGeneratorArgs(                                              \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".float(float a) -> " #result), \
      [](Stack& stack) {                                              \
        double a;                                                     \
        pop(stack, a);                                                \
        push(stack, op);                                              \
      },                                                              \
      aliasAnalysisFromSchema())

#define DEFINE_UNARY_OP(aten_op, op, int_result, float_result)            \
  DEFINE_UNARY_INT_OP(aten_op, op, int_result),                           \
      DEFINE_UNARY_FLOAT_OP(aten_op, op, float_result),                   \
      OperatorGeneratorArgs(                                              \
          TORCH_SELECTIVE_SCHEMA(#aten_op ".Scalar(Scalar a) -> Scalar"), \
          [](Stack& stack) {                                              \
            IValue x;                                                     \
            pop(stack, x);                                                \
            if (x.isDouble()) {                                           \
              double a = x.toDouble();                                    \
              push(stack, static_cast<float_result>(op));                 \
            } else {                                                      \
              int64_t a = x.toInt();                                      \
              push(stack, static_cast<int_result>(op));                   \
            }                                                             \
          },                                                              \
          aliasAnalysisFromSchema())
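
// Unary usage sketch (hypothetical name):
//   DEFINE_UNARY_OP(aten::example_neg, -a, int, float)
// emits .int, .float, and .Scalar overloads, the Scalar one dispatching on
// the runtime tag of the popped IValue.
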
#define DEFINE_BOOL_OP(aten_op, op)                                     \
  OperatorGeneratorArgs(                                                \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".bool(bool a, bool b) -> bool"), \
      [](Stack& stack) {                                                \
        bool a, b;                                                      \
        pop(stack, a, b);                                               \
        push(stack, op);                                                \
      },                                                                \
      aliasAnalysisFromSchema())
#define DEFINE_STRING_OP(op_name, string_op, result)                     \
  OperatorGeneratorArgs(                                                 \
      TORCH_SELECTIVE_SCHEMA(#op_name ".str(str a, str b) -> " #result), \
      [](Stack& stack) {                                                 \
        auto b = pop(stack).toStringRef();                               \
        auto a = pop(stack).toStringRef();                               \
        push(stack, string_op);                                          \
      },                                                                 \
      aliasAnalysisFromSchema())

//-----------------------------------------------------------------------------
//-----------------------------------------------------------------------------
//-----------------------------------------------------------------------------
//-----------------------------------------------------------------------------
#define DEFINE_UNARY_COMPLEX_OP(aten_op, op, result)                      \
  OperatorGeneratorArgs(                                                  \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".complex(complex a) -> " #result), \
      [](Stack& stack) {                                                  \
        c10::complex<double> a;                                           \
        pop(stack, a);                                                    \
        push(stack, op);                                                  \
      },                                                                  \
      aliasAnalysisFromSchema())

// Some complex unary ops (like abs, angle) return real valued output, but most
// other unary ops return complex valued output. So, this macro is used in the
// former case where we can explicitly pass complex_result_cast argument, which
// is set to c10::complex<double> in the macro `DEFINE_UNARY_OP_WITH_COMPLEX`
// defined below.
#define DEFINE_UNARY_OP_WITH_COMPLEX_CAST(                                \
    aten_op,                                                              \
    op,                                                                   \
    int_result,                                                           \
    float_result,                                                         \
    complex_result,                                                       \
    complex_result_cast)                                                  \
  DEFINE_UNARY_INT_OP(aten_op, op, int_result),                           \
      DEFINE_UNARY_FLOAT_OP(aten_op, op, float_result),                   \
      DEFINE_UNARY_COMPLEX_OP(aten_op, op, complex_result),               \
      OperatorGeneratorArgs(                                              \
          TORCH_SELECTIVE_SCHEMA(#aten_op ".Scalar(Scalar a) -> Scalar"), \
          [](Stack& stack) {                                              \
            IValue x;                                                     \
            pop(stack, x);                                                \
            if (x.isDouble()) {                                           \
              double a = x.toDouble();                                    \
              push(stack, static_cast<float_result>(op));                 \
            } else if (x.isComplexDouble()) {                             \
              c10::complex<double> a = x.toComplexDouble();               \
              push(stack, static_cast<complex_result_cast>(op));          \
            } else {                                                      \
              int64_t a = x.toInt();                                      \
              push(stack, static_cast<int_result>(op));                   \
            }                                                             \
          },                                                              \
          aliasAnalysisFromSchema())
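
// A real-valued complex unary op might be registered as (illustrative only):
//   DEFINE_UNARY_OP_WITH_COMPLEX_CAST(aten::abs, std::abs(a), int, float, float, double)
// since std::abs of a c10::complex<double> yields a double, the
// complex_result_cast is double rather than a complex type.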

#define DEFINE_UNARY_OP_WITH_COMPLEX(aten_op, op, int_result, float_result) \
  DEFINE_UNARY_OP_WITH_COMPLEX_CAST(                                        \
      aten_op, op, int_result, float_result, complex, c10::complex<double>)

#define DEFINE_GENERIC_OP_WITH_COMPLEX(                                       \
    aten_op,                                                                  \
    int_op,                                                                   \
    float_op,                                                                 \
    complex_op,                                                               \
    int_result,                                                               \
    float_result,                                                             \
    complex_result)                                                           \
  OperatorGeneratorArgs(                                                      \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a, int b) -> " #int_result),  \
      [](Stack& stack) {                                                      \
        int64_t a, b;                                                         \
        pop(stack, a, b);                                                     \
        push(stack, int_op);                                                  \
      },                                                                      \
      aliasAnalysisFromSchema()),                                             \
      OperatorGeneratorArgs(                                                  \
          TORCH_SELECTIVE_SCHEMA(                                             \
              #aten_op ".complex(complex a, complex b) -> " #complex_result), \
          [](Stack& stack) {                                                  \
            c10::complex<double> a, b;                                        \
            pop(stack, a, b);                                                 \
            push(stack, complex_op);                                          \
          },                                                                  \
          aliasAnalysisFromSchema()),                                         \
      OperatorGeneratorArgs(                                                  \
          TORCH_SELECTIVE_SCHEMA(                                             \
              #aten_op ".float(float a, float b) -> " #float_result),         \
          [](Stack& stack) {                                                  \
            double a, b;                                                      \
            pop(stack, a, b);                                                 \
            push(stack, float_op);                                            \
          },                                                                  \
          aliasAnalysisFromSchema())

#define DEFINE_INT_COMPLEX_OP(aten_op, op, result)                          \
  OperatorGeneratorArgs(                                                    \
      TORCH_SELECTIVE_SCHEMA(#aten_op                                       \
                             ".int_complex(int a, complex b) -> " #result), \
      [](Stack& stack) {                                                    \
        int64_t a;                                                          \
        c10::complex<double> b;                                             \
        pop(stack, a, b);                                                   \
        push(stack, op);                                                    \
      },                                                                    \
      aliasAnalysisFromSchema()),                                           \
      OperatorGeneratorArgs(                                                \
          TORCH_SELECTIVE_SCHEMA(                                           \
              #aten_op ".complex_int(complex a, int b) -> " #result),       \
          [](Stack& stack) {                                                \
            c10::complex<double> a;                                         \
            int64_t b;                                                      \
            pop(stack, a, b);                                               \
            push(stack, op);                                                \
          },                                                                \
          aliasAnalysisFromSchema())

#define DEFINE_FLOAT_COMPLEX_OP(aten_op, op, result)                      \
  OperatorGeneratorArgs(                                                  \
      TORCH_SELECTIVE_SCHEMA(                                             \
          #aten_op ".float_complex(float a, complex b) -> " #result),     \
      [](Stack& stack) {                                                  \
        double a;                                                         \
        c10::complex<double> b;                                           \
        pop(stack, a, b);                                                 \
        push(stack, op);                                                  \
      },                                                                  \
      aliasAnalysisFromSchema()),                                         \
      OperatorGeneratorArgs(                                              \
          TORCH_SELECTIVE_SCHEMA(                                         \
              #aten_op ".complex_float(complex a, float b) -> " #result), \
          [](Stack& stack) {                                              \
            c10::complex<double> a;                                       \
            double b;                                                     \
            pop(stack, a, b);                                             \
            push(stack, op);                                              \
          },                                                              \
          aliasAnalysisFromSchema())

#define DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_AVOID_COLLISION_GENERIC( \
    aten_op, int_op, float_op, complex_op, result, string_val)        \
  OperatorGeneratorArgs(                                              \
      TORCH_SELECTIVE_SCHEMA(#aten_op string_val                      \
                             "(Scalar a, Scalar b) -> " #result),     \
      [](Stack& stack) {                                              \
        IValue x, y;                                                  \
        pop(stack, x, y);                                             \
        if (x.isComplexDouble()) {                                    \
          c10::complex<double> a = x.toComplexDouble();               \
          if (y.isComplexDouble()) {                                  \
            c10::complex<double> b = y.toComplexDouble();             \
            push(stack, complex_op);                                  \
          } else if (y.isDouble()) {                                  \
            double b = y.toDouble();                                  \
            push(stack, complex_op);                                  \
          } else {                                                    \
            int64_t b = y.toInt();                                    \
            push(stack, complex_op);                                  \
          }                                                           \
        } else if (x.isDouble()) {                                    \
          double a = x.toDouble();                                    \
          if (y.isComplexDouble()) {                                  \
            c10::complex<double> b = y.toComplexDouble();             \
            push(stack, complex_op);                                  \
          } else if (y.isDouble()) {                                  \
            double b = y.toDouble();                                  \
            push(stack, float_op);                                    \
          } else {                                                    \
            int64_t b = y.toInt();                                    \
            push(stack, float_op);                                    \
          }                                                           \
        } else {                                                      \
          int64_t a = x.toInt();                                      \
          if (y.isComplexDouble()) {                                  \
            c10::complex<double> b = y.toComplexDouble();             \
            push(stack, complex_op);                                  \
          } else if (y.isDouble()) {                                  \
            double b = y.toDouble();                                  \
            push(stack, float_op);                                    \
          } else {                                                    \
            int64_t b = y.toInt();                                    \
            push(stack, int_op);                                      \
          }                                                           \
        }                                                             \
      },                                                              \
      aliasAnalysisFromSchema())

#define DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_WITHOUT_INT_COMPLEX_PAIR(     \
    aten_op, int_op, float_op, complex_op, result)                         \
  OperatorGeneratorArgs(                                                   \
      TORCH_SELECTIVE_SCHEMA(#aten_op "(Scalar a, Scalar b) -> " #result), \
      [](Stack& stack) {                                                   \
        IValue x, y;                                                       \
        pop(stack, x, y);                                                  \
        if (x.isComplexDouble()) {                                         \
          c10::complex<double> a = x.toComplexDouble();                    \
          if (y.isComplexDouble()) {                                       \
            c10::complex<double> b = y.toComplexDouble();                  \
            push(stack, complex_op);                                       \
          } else if (y.isDouble()) {                                       \
            double b = y.toDouble();                                       \
            push(stack, complex_op);                                       \
          }                                                                \
        } else if (x.isDouble()) {                                         \
          double a = x.toDouble();                                         \
          if (y.isComplexDouble()) {                                       \
            c10::complex<double> b = y.toComplexDouble();                  \
            push(stack, complex_op);                                       \
          } else if (y.isDouble()) {                                       \
            double b = y.toDouble();                                       \
            push(stack, float_op);                                         \
          } else {                                                         \
            int64_t b = y.toInt();                                         \
            push(stack, float_op);                                         \
          }                                                                \
        } else {                                                           \
          int64_t a = x.toInt();                                           \
          if (y.isDouble()) {                                              \
            double b = y.toDouble();                                       \
            push(stack, float_op);                                         \
          } else if (y.isInt()) {                                          \
            int64_t b = y.toInt();                                         \
            push(stack, int_op);                                           \
          }                                                                \
        }                                                                  \
      },                                                                   \
      aliasAnalysisFromSchema())

#define DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX(                   \
    aten_op, int_op, float_op, complex_op, result)              \
  DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_AVOID_COLLISION_GENERIC( \
      aten_op, int_op, float_op, complex_op, result, "")

#define DEFINE_BINARY_OP_WITH_COMPLEX(aten_op, op)                          \
  DEFINE_GENERIC_OP_WITH_COMPLEX(aten_op, op, op, op, int, float, complex), \
      DEFINE_INT_COMPLEX_OP(aten_op, op, complex),                          \
      DEFINE_FLOAT_COMPLEX_OP(aten_op, op, complex),                        \
      DEFINE_INT_FLOAT_OP(aten_op, op, float),                              \
      DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX(aten_op, op, op, op, Scalar)

#define DEFINE_COMPARISON_OP_WITH_COMPLEX(aten_op, op)                   \
  DEFINE_GENERIC_OP_WITH_COMPLEX(aten_op, op, op, op, bool, bool, bool), \
      DEFINE_INT_FLOAT_OP(aten_op, op, bool),                            \
      DEFINE_FLOAT_COMPLEX_OP(aten_op, op, bool),                        \
      DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_WITHOUT_INT_COMPLEX_PAIR(     \
          aten_op, op, op, op, bool),                                    \
      DEFINE_STR_CMP_OP(aten_op, op)

TORCH_API at::Generator make_generator_for_device(
    c10::Device device,
    std::optional<int64_t> seed = std::nullopt);

} // namespace torch::jit