1 #pragma once
2
3 #include <ATen/Context.h>
4 #include <c10/core/DeviceType.h>
5 #include <torch/csrc/autograd/autograd.h>
6 #include <torch/csrc/autograd/edge.h>
7 #include <torch/csrc/autograd/function.h>
8 #include <torch/csrc/autograd/generated/variable_factories.h>
9 #include <torch/csrc/autograd/variable.h>
10 #include <torch/csrc/jit/api/compilation_unit.h>
11 #include <torch/csrc/jit/api/module.h>
12 #include <torch/csrc/jit/frontend/error_report.h>
13 #include <torch/csrc/jit/ir/ir.h>
14 #include <torch/csrc/jit/mobile/register_ops_common_utils.h>
15 #include <torch/csrc/jit/runtime/custom_operator.h>
16 #include <torch/csrc/jit/runtime/graph_executor.h>
17 #include <torch/csrc/jit/runtime/jit_exception.h>
18 #include <torch/csrc/jit/runtime/logging.h>
19 #include <torch/csrc/jit/runtime/operator.h>
20 #include <torch/csrc/jit/runtime/print_handler.h>
21 #include <torch/csrc/jit/runtime/profiling_record.h>
22 #include <torch/csrc/jit/runtime/vararg_functions.h>
23 #include <torch/csrc/jit/serialization/pickle.h>
24
25 #include <ATen/ExpandUtils.h>
26 #include <ATen/Parallel.h>
27 #include <ATen/WrapDimUtils.h>
28 #include <ATen/core/Dict.h>
29 #include <ATen/core/Generator.h>
30 #include <ATen/core/ivalue.h>
31 #include <c10/core/Device.h>
32 #include <c10/core/thread_pool.h>
33 #include <c10/util/SmallVector.h>
34 #include <c10/util/irange.h>
35
36 namespace torch::jit {
aliasAnalysisFromSchema()37 constexpr inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
38 return c10::AliasAnalysisKind::FROM_SCHEMA;
39 }
40
aliasAnalysisConservative()41 constexpr inline c10::AliasAnalysisKind aliasAnalysisConservative() {
42 return c10::AliasAnalysisKind::CONSERVATIVE;
43 }
44
aliasAnalysisSpecialCase()45 constexpr inline c10::AliasAnalysisKind aliasAnalysisSpecialCase() {
46 return c10::AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
47 }
48
49 template <class T>
make_result_list(const TypePtr & elemType)50 c10::List<T> make_result_list(const TypePtr& elemType) {
51 return c10::List<T>();
52 }
53
54 template <>
55 c10::impl::GenericList make_result_list<IValue>(const TypePtr& elemType);
56
// Round-half-to-even ("banker's rounding"), matching Python's builtin
// round(): https://docs.python.org/3/library/functions.html#round
// When a value is exactly halfway between two integers we compute
// round(x/2)*2: writing x = x_e + x_r with x_e an even integer and
// x_r either 0.5 or 1.5, round(x/2)*2 = x_e + round(x_r/2)*2, and
// round(x_r/2)*2 is 0 or 2, so the result is always even. By symmetry the
// same argument covers negative inputs.
inline double round_to_even(double a) {
  const double frac = a - std::floor(a);
  if (frac == 0.5) {
    // Exactly halfway: halve, round, double — lands on the even neighbor.
    return std::round(a * 0.5) * 2.0;
  }
  return std::round(a);
}
68
// Validate an implicit Tensor -> number conversion, using the rules from
// python_arg_parser FunctionParameter::check: the tensor cannot have grad
// set, the tensor must be 0 dim, and if the dest is an int the source must
// be an integral type. (Defined in the .cpp.)
void checkImplicitTensorToNum(const at::Tensor& t, bool toInt);
73
// Python-style integer floor division: the result rounds toward negative
// infinity, unlike C++ '/', which truncates toward zero.
// Throws std::runtime_error on division by zero.
[[maybe_unused]] static int64_t floordiv(int64_t a, int64_t b) {
  if (b == 0) {
    throw std::runtime_error("division by 0");
  }
  if ((a > 0) == (b > 0)) {
    // Same sign: truncation and floor agree, so plain division is correct.
    return a / b;
  } else {
    // Opposite signs: C++ division truncates toward zero, but Python's floor
    // division rounds down; step down by one when there is a remainder.
    auto r = lldiv(a, b);
    return (r.rem) ? r.quot - 1 : r.quot;
  }
}
87 TORCH_API void checkDoubleInRange(double a);
floor(double a)88 static C10_UNUSED int64_t floor(double a) {
89 checkDoubleInRange(a);
90 return std::floor(a);
91 }
ceil(double a)92 static C10_UNUSED int64_t ceil(double a) {
93 checkDoubleInRange(a);
94 return std::ceil(a);
95 }
96
// Greatest common divisor via the Euclidean algorithm. Like Python's
// math.gcd, the result is always non-negative, and gcd(0, 0) == 0.
[[maybe_unused]] static int64_t gcd(int64_t a, int64_t b) {
  while (b != 0) {
    const int64_t r = a % b;
    a = b;
    b = r;
  }
  // in python gcd returns non-negative values
  return std::abs(a);
}
106
// Helpers for the factorial implementation, all defined in the .cpp.
// NOTE(review): semantics below are inferred from the names — confirm
// against the .cpp definitions.
int64_t partProduct(int n, int m);

void loop(int n, int64_t& p, int64_t& r);

int nminussumofbits(int v);

// Presumably n! for non-negative n (defined in the .cpp).
int64_t factorial(int n);
// Degree/radian conversion factors; std::acos(-1.0) == pi.
static const double degToRad = std::acos(-1.0) / 180.0;
static const double radToDeg = 180.0 / std::acos(-1.0);
// Python math.degrees / math.radians equivalents (defined in the .cpp).
double degrees(double x);
double radians(double x);
118
119 // Convert an python index (which may be negative) into an index usable for a
120 // C++ container
121
122 // Equivalent to list.at(idx)
123 template <typename T>
decltype(auto)124 decltype(auto) getItem(const c10::List<T>& list, int64_t idx) {
125 const int64_t list_size = list.size();
126 const int64_t normalized_idx = normalizeIndex(idx, list_size);
127 if (normalized_idx < 0 || normalized_idx >= list_size) {
128 throw std::out_of_range("list index out of range");
129 }
130 return list.get(normalized_idx);
131 }
132
133 template <typename T>
setItem(const c10::List<T> & list,int64_t idx,T && value)134 void setItem(const c10::List<T>& list, int64_t idx, T&& value) {
135 const int64_t list_size = list.size();
136 const int64_t normalized_idx = normalizeIndex(idx, list_size);
137 if (normalized_idx < 0 || normalized_idx >= list_size) {
138 throw std::out_of_range("list index out of range");
139 }
140 list.set(normalized_idx, std::forward<T>(value));
141 }
142
// Stack-calling-convention list ops (see the templates below for the
// convention); defined in the .cpp.
void listAppend(Stack& stack);

void listReverse(Stack& stack);
146
147 template <typename T>
minList(Stack & stack)148 void minList(Stack& stack) {
149 c10::List<T> a = pop(stack).to<c10::List<T>>();
150 c10::List<T> b = pop(stack).to<c10::List<T>>();
151
152 size_t min_size = std::min(a.size(), b.size());
153 for (const auto i : c10::irange(min_size)) {
154 if (a[i] == b[i]) {
155 continue;
156 }
157
158 push(stack, a[i] < b[i] ? a : b);
159 return;
160 }
161
162 push(stack, b.size() < a.size() ? b : a);
163 }
164
165 template <typename T>
maxList(Stack & stack)166 void maxList(Stack& stack) {
167 c10::List<T> a = pop(stack).to<c10::List<T>>();
168 c10::List<T> b = pop(stack).to<c10::List<T>>();
169
170 size_t min_size = std::min(a.size(), b.size());
171 for (const auto i : c10::irange(min_size)) {
172 if (a[i] == b[i]) {
173 continue;
174 }
175
176 push(stack, a[i] > b[i] ? a : b);
177 return;
178 }
179
180 push(stack, b.size() > a.size() ? b : a);
181 }
182
// Shared implementation behind list.pop variants; `empty_message` is the
// error text raised when popping from an empty list. Defined in the .cpp.
void listPopImpl(Stack& stack, const char* empty_message);

void listPop(Stack& stack);

void listClear(Stack& stack);

void listDelete(Stack& stack);

void listInsert(Stack& stack);
192
193 template <typename T>
listRemove(Stack & stack)194 void listRemove(Stack& stack) {
195 T elem = pop(stack).to<T>();
196 c10::List<T> list = pop(stack).to<c10::List<T>>();
197
198 auto pos = std::find(list.begin(), list.end(), elem);
199
200 if (pos != list.end()) {
201 list.erase(pos);
202 } else {
203 AT_ERROR("list.remove(x): x not in list");
204 }
205 }
206
207 template <typename T>
listMin(Stack & stack)208 void listMin(Stack& stack) {
209 c10::List<T> list = pop(stack).to<c10::List<T>>();
210 size_t list_size = list.size();
211 if (list_size == 0) {
212 throw std::runtime_error("min() arg is an empty sequence");
213 }
214
215 T min_elem = list[0];
216 for (const auto i : c10::irange(1, list_size)) {
217 T elem = list[i];
218 min_elem = elem < min_elem ? elem : min_elem;
219 }
220
221 stack.push_back(min_elem);
222 }
223
224 template <typename T>
listMax(Stack & stack)225 void listMax(Stack& stack) {
226 c10::List<T> list = pop(stack).to<c10::List<T>>();
227 size_t list_size = list.size();
228 if (list_size == 0) {
229 throw std::runtime_error("max() arg is an empty sequence");
230 }
231
232 T max_elem = list[0];
233 for (const auto i : c10::irange(1, list_size)) {
234 T elem = list[i];
235 max_elem = elem > max_elem ? elem : max_elem;
236 }
237
238 stack.push_back(max_elem);
239 }
240
// Specialization for at::Tensor, since it doesn't define operator==
// (defined in the .cpp).
template <>
void listRemove<at::Tensor>(Stack& stack);
243
244 template <typename T>
listIndex(Stack & stack)245 void listIndex(Stack& stack) {
246 T elem = pop(stack).to<T>();
247 c10::List<T> list = pop(stack).to<c10::List<T>>();
248
249 auto pos = std::find(list.begin(), list.end(), elem);
250
251 if (pos != list.end()) {
252 push(stack, static_cast<int64_t>(std::distance(list.begin(), pos)));
253 } else {
254 AT_ERROR("'", elem, "' is not in list");
255 }
256 }
257
258 template <>
259 void listIndex<at::Tensor>(Stack& stack);
260
261 template <typename T>
listCount(Stack & stack)262 void listCount(Stack& stack) {
263 T elem = pop(stack).to<T>();
264 c10::List<T> list = pop(stack).to<c10::List<T>>();
265
266 const int64_t count = std::count(list.begin(), list.end(), elem);
267 push(stack, count);
268 }
269
270 template <>
271 void listCount<at::Tensor>(Stack& stack);
272
// More stack-calling-convention list ops, defined in the .cpp.
void listExtend(Stack& stack);

void listCopy(Stack& stack);

void listSelect(Stack& stack);

void listLen(Stack& stack);
280
281 template <typename T>
listEq(Stack & stack)282 void listEq(Stack& stack) {
283 c10::List<T> b = pop(stack).to<c10::List<T>>();
284 c10::List<T> a = pop(stack).to<c10::List<T>>();
285 push(stack, a == b);
286 }
287
288 template <typename T>
listNe(Stack & stack)289 void listNe(Stack& stack) {
290 c10::List<T> b = pop(stack).to<c10::List<T>>();
291 c10::List<T> a = pop(stack).to<c10::List<T>>();
292 push(stack, a != b);
293 }
294
tensor_list_equal(const c10::List<at::Tensor> & a,const c10::List<at::Tensor> & b)295 inline bool tensor_list_equal(
296 const c10::List<at::Tensor>& a,
297 const c10::List<at::Tensor>& b) {
298 if (a.size() != b.size()) {
299 return false;
300 }
301
302 for (const auto i : c10::irange(a.size())) {
303 const at::Tensor& a_element = a[i];
304 const at::Tensor& b_element = b[i];
305 // This preserves Python's semantics, which uses eq() to compare two
306 // elements, then passes the result to bool().
307 // see: https://docs.python.org/3.4/reference/datamodel.html#object.__ge__
308 const auto cmp_result = a_element.eq(b_element);
309 if (!at::native::is_nonzero(cmp_result)) {
310 return false;
311 }
312 }
313
314 return true;
315 }
316
// Specialization for at::Tensor, since it doesn't define operator==
template <>
void listEq<at::Tensor>(Stack& stack);

// Specialization for at::Tensor, since it doesn't define operator==
template <>
void listNe<at::Tensor>(Stack& stack);

// Defined in the .cpp.
void listList(Stack& stack);
326
327 template <typename T>
listContains(Stack & stack)328 void listContains(Stack& stack) {
329 auto key = pop(stack).to<T>();
330 auto list = pop(stack).to<c10::List<T>>();
331 // NOLINTNEXTLINE(performance-implicit-conversion-in-loop)
332 for (const T& item : list) {
333 if (item == key) {
334 push(stack, true);
335 return;
336 }
337 }
338 push(stack, false);
339 }
340
// List arithmetic/slicing ops (stack calling convention), defined in the
// .cpp.
void listAdd(Stack& stack);

void listInplaceAdd(Stack& stack);

void listMulIntLeftInPlace(Stack& stack);

void listMulIntLeft(Stack& stack);

void listMulIntRight(Stack& stack);

void listSlice(Stack& stack);
352
353 template <typename T>
listSort(Stack & stack)354 void listSort(Stack& stack) {
355 bool reverse = pop(stack).toBool();
356 c10::List<T> list = pop(stack).to<c10::List<T>>();
357 std::sort(list.begin(), list.end(), [reverse](const T& a, const T& b) {
358 // FBCode errors without this check - "strict weak ordering"
359 // TODO: remove when possible, since it just slows down
360 // sorting and doesn't do anything useful
361 if (a == b) {
362 return false;
363 }
364 return (a < b) != reverse;
365 });
366 }
367
368 // Specialization for at::Tensor
369 template <>
370 void listSort<at::Tensor>(Stack& stack);
371
372 template <typename T>
listCopyAndSort(Stack & stack)373 void listCopyAndSort(Stack& stack) {
374 c10::List<T> list = pop(stack).to<c10::List<T>>();
375 auto list_copied = list.copy();
376 std::sort(list_copied.begin(), list_copied.end(), [](const T& a, const T& b) {
377 // "strict weak ordering" issue - see other sort
378 if (a == b) {
379 return false;
380 }
381 return a < b;
382 });
383 push(stack, list_copied);
384 }
385
386 // Specialization for at::Tensor
387 template <>
388 void listCopyAndSort<at::Tensor>(Stack& stack);
389
390 void listSetItem(Stack& stack);
391
// One operator-registration record: a (selective-build) schema string, the
// implementation — either a direct stack operation or an OperationCreator —
// and the alias-analysis kind to register the op with.
struct OperatorGeneratorArgs {
  const char* schema_str;
  // Discriminant for the union below: true -> operationCreator is the
  // active member, false -> operation is.
  bool isOperationCreator;
  union {
    void (*operation)(Stack&);
    OperationCreator operationCreator;
  };
  AliasAnalysisKind aliasAnalysis;

  // Selected-in (SelectiveStr<true>) overload taking a direct operation.
  explicit constexpr OperatorGeneratorArgs(
      torch::detail::SelectiveStr<true> schema_str,
      void (*op)(Stack&),
      AliasAnalysisKind aa)
      : schema_str(schema_str),
        isOperationCreator(false),
        operation(op),
        aliasAnalysis(aa) {}

  // Selected-in overload taking an OperationCreator instead.
  explicit constexpr OperatorGeneratorArgs(
      torch::detail::SelectiveStr<true> schema_str,
      OperationCreator opCreator,
      AliasAnalysisKind aa)
      : schema_str(schema_str),
        isOperationCreator(true),
        operationCreator(opCreator),
        aliasAnalysis(aa) {}

  // Selected-out (SelectiveStr<false>) overload: the operator was excluded
  // by selective build, so record a null schema and no implementation.
  template <typename... Args>
  explicit constexpr OperatorGeneratorArgs(
      torch::detail::SelectiveStr<false>,
      Args...)
      : schema_str(nullptr),
        isOperationCreator(false),
        operation(nullptr),
        aliasAnalysis(AliasAnalysisKind::INTERNAL_SPECIAL_CASE) {}
};
428
// Registers ".int_int", ".float_float" and ".complex_complex" overloads of
// `aten_op`; `op` is an expression evaluated with locals `a` and `b` bound
// to the popped operands.
#define DEFINE_GENERIC_BINARY_OP(                                        \
    aten_op, op, int_float_result, complex_result)                       \
  OperatorGeneratorArgs(                                                 \
      TORCH_SELECTIVE_SCHEMA(#aten_op                                    \
                             ".int_int(int a, int b) -> " #int_float_result), \
      [](Stack& stack) {                                                 \
        int64_t a, b;                                                    \
        pop(stack, a, b);                                                \
        push(stack, op);                                                 \
      },                                                                 \
      aliasAnalysisFromSchema()),                                        \
      OperatorGeneratorArgs(                                             \
          TORCH_SELECTIVE_SCHEMA(                                        \
              #aten_op                                                   \
              ".float_float(float a, float b) -> " #int_float_result),   \
          [](Stack& stack) {                                             \
            double a, b;                                                 \
            pop(stack, a, b);                                            \
            push(stack, op);                                             \
          },                                                             \
          aliasAnalysisFromSchema()),                                    \
      OperatorGeneratorArgs(                                             \
          TORCH_SELECTIVE_SCHEMA(                                        \
              #aten_op                                                   \
              ".complex_complex(complex a, complex b) -> " #complex_result), \
          [](Stack& stack) {                                             \
            c10::complex<double> a, b;                                   \
            pop(stack, a, b);                                            \
            push(stack, op);                                             \
          },                                                             \
          aliasAnalysisFromSchema())

// define implementations for primitive number ops
// Registers ".int" and ".float" overloads; int_op/float_op are expressions
// over locals `a` and `b`.
#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
  OperatorGeneratorArgs(                                                 \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a, int b) -> " #int_result), \
      [](Stack& stack) {                                                 \
        int64_t a, b;                                                    \
        pop(stack, a, b);                                                \
        push(stack, int_op);                                             \
      },                                                                 \
      aliasAnalysisFromSchema()),                                        \
      OperatorGeneratorArgs(                                             \
          TORCH_SELECTIVE_SCHEMA(                                        \
              #aten_op ".float(float a, float b) -> " #float_result),    \
          [](Stack& stack) {                                             \
            double a, b;                                                 \
            pop(stack, a, b);                                            \
            push(stack, float_op);                                       \
          },                                                             \
          aliasAnalysisFromSchema())

// Registers the mixed ".int_float" and ".float_int" overloads.
#define DEFINE_INT_FLOAT_OP(aten_op, op, result)                         \
  OperatorGeneratorArgs(                                                 \
      TORCH_SELECTIVE_SCHEMA(#aten_op                                    \
                             ".int_float(int a, float b) -> " #result),  \
      [](Stack& stack) {                                                 \
        int64_t a;                                                       \
        double b;                                                        \
        pop(stack, a, b);                                                \
        push(stack, op);                                                 \
      },                                                                 \
      aliasAnalysisFromSchema()),                                        \
      OperatorGeneratorArgs(                                             \
          TORCH_SELECTIVE_SCHEMA(#aten_op                                \
                                 ".float_int(float a, int b) -> " #result), \
          [](Stack& stack) {                                             \
            double a;                                                    \
            int64_t b;                                                   \
            pop(stack, a, b);                                            \
            push(stack, op);                                             \
          },                                                             \
          aliasAnalysisFromSchema())

// Registers an int-only ".int" overload (used e.g. for bitwise ops).
#define DEFINE_INT_OP(aten_op, op)                                       \
  OperatorGeneratorArgs(                                                 \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a, int b) -> int"),      \
      [](Stack& stack) {                                                 \
        int64_t a, b;                                                    \
        pop(stack, a, b);                                                \
        push(stack, op); /* NOLINT(hicpp-signed-bitwise) */              \
      },                                                                 \
      aliasAnalysisFromSchema())

// Registers a string-comparison ".str" overload returning bool.
#define DEFINE_STR_CMP_OP(aten_op, op)                                   \
  OperatorGeneratorArgs(                                                 \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".str(str a, str b) -> bool"),     \
      [](Stack& stack) {                                                 \
        auto b = pop(stack).toStringRef();                               \
        auto a = pop(stack).toStringRef();                               \
        push(stack, op);                                                 \
      },                                                                 \
      aliasAnalysisFromSchema())
522
// define a primitive op over Scalar operands.
// it's necessary to register this overload following
// int/float variations to avoid trapping Scalar args
// in unintended implicit conversions
// Dispatches at runtime on the dynamic int/float tags of the two IValues;
// `string_val` is spliced into the schema name to disambiguate overloads.
#define DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION_GENERIC(                 \
    aten_op, int_op, float_op, result, string_val)                       \
  OperatorGeneratorArgs(                                                 \
      TORCH_SELECTIVE_SCHEMA(#aten_op string_val                         \
                             "(Scalar a, Scalar b) -> " #result),        \
      [](Stack& stack) {                                                 \
        IValue x, y;                                                     \
        pop(stack, x, y);                                                \
        if (x.isDouble()) {                                              \
          if (y.isDouble()) {                                            \
            double a = x.toDouble();                                     \
            double b = y.toDouble();                                     \
            push(stack, float_op);                                       \
          } else {                                                       \
            double a = x.toDouble();                                     \
            int64_t b = y.toInt();                                       \
            push(stack, float_op);                                       \
          }                                                              \
        } else {                                                         \
          if (y.isDouble()) {                                            \
            int64_t a = x.toInt();                                       \
            double b = y.toDouble();                                     \
            push(stack, float_op);                                       \
          } else {                                                       \
            int64_t a = x.toInt();                                       \
            int64_t b = y.toInt();                                       \
            push(stack, int_op);                                         \
          }                                                              \
        }                                                                \
      },                                                                 \
      aliasAnalysisFromSchema())

// Plain Scalar-Scalar overload (no overload-name suffix).
#define DEFINE_SCALAR_BINARY_OP(aten_op, int_op, float_op, result)       \
  DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION_GENERIC(                       \
      aten_op, int_op, float_op, result, "")

// Scalar-Scalar overload with a ".Scalar_Scalar" suffix to avoid schema
// collisions.
#define DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION(                         \
    aten_op, int_op, float_op, result)                                   \
  DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION_GENERIC(                       \
      aten_op, int_op, float_op, result, ".Scalar_Scalar")

// Full set of overloads for an arithmetic binary op (int result for ints).
#define DEFINE_BINARY_OP(aten_op, op)                                    \
  DEFINE_GENERIC_OP(aten_op, op, op, int, float),                        \
      DEFINE_INT_FLOAT_OP(aten_op, op, float),                           \
      DEFINE_SCALAR_BINARY_OP(aten_op, op, op, Scalar)

// Same, but the op always produces a float (e.g. true division).
#define DEFINE_BINARY_FLOAT_OP(aten_op, op)                              \
  DEFINE_GENERIC_OP(aten_op, op, op, float, float),                      \
      DEFINE_INT_FLOAT_OP(aten_op, op, float),                          \
      DEFINE_SCALAR_BINARY_OP(aten_op, op, op, float)

// Full set of overloads for a comparison op (bool results + str compare).
#define DEFINE_COMPARISON_OP(aten_op, op)                                \
  DEFINE_GENERIC_OP(aten_op, op, op, bool, bool),                        \
      DEFINE_INT_FLOAT_OP(aten_op, op, bool),                            \
      DEFINE_SCALAR_BINARY_OP(aten_op, op, op, bool),                    \
      DEFINE_STR_CMP_OP(aten_op, op)
583
// Unary ".int" overload; `op` is an expression over local `a`.
#define DEFINE_UNARY_INT_OP(aten_op, op, result)                         \
  OperatorGeneratorArgs(                                                 \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a) -> " #result),        \
      [](Stack& stack) {                                                 \
        int64_t a;                                                       \
        pop(stack, a);                                                   \
        push(stack, op);                                                 \
      },                                                                 \
      aliasAnalysisFromSchema())

// Unary ".float" overload; `op` is an expression over local `a`.
#define DEFINE_UNARY_FLOAT_OP(aten_op, op, result)                       \
  OperatorGeneratorArgs(                                                 \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".float(float a) -> " #result),    \
      [](Stack& stack) {                                                 \
        double a;                                                        \
        pop(stack, a);                                                   \
        push(stack, op);                                                 \
      },                                                                 \
      aliasAnalysisFromSchema())

// Unary op over int, float and Scalar; the Scalar variant dispatches on the
// IValue's dynamic tag and casts to the matching result type.
#define DEFINE_UNARY_OP(aten_op, op, int_result, float_result)           \
  DEFINE_UNARY_INT_OP(aten_op, op, int_result),                          \
      DEFINE_UNARY_FLOAT_OP(aten_op, op, float_result),                  \
      OperatorGeneratorArgs(                                             \
          TORCH_SELECTIVE_SCHEMA(#aten_op ".Scalar(Scalar a) -> Scalar"), \
          [](Stack& stack) {                                             \
            IValue x;                                                    \
            pop(stack, x);                                               \
            if (x.isDouble()) {                                          \
              double a = x.toDouble();                                   \
              push(stack, static_cast<float_result>(op));                \
            } else {                                                     \
              int64_t a = x.toInt();                                     \
              push(stack, static_cast<int_result>(op));                  \
            }                                                            \
          },                                                             \
          aliasAnalysisFromSchema())
// Binary ".bool" overload; `op` is an expression over locals `a` and `b`.
#define DEFINE_BOOL_OP(aten_op, op)                                      \
  OperatorGeneratorArgs(                                                 \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".bool(bool a, bool b) -> bool"),  \
      [](Stack& stack) {                                                 \
        bool a, b;                                                       \
        pop(stack, a, b);                                                \
        push(stack, op);                                                 \
      },                                                                 \
      aliasAnalysisFromSchema())
// Binary ".str" overload; `string_op` sees locals `a` and `b` as strings.
#define DEFINE_STRING_OP(op_name, string_op, result)                     \
  OperatorGeneratorArgs(                                                 \
      TORCH_SELECTIVE_SCHEMA(#op_name ".str(str a, str b) ->" #result),  \
      [](Stack& stack) {                                                 \
        auto b = pop(stack).toStringRef();                               \
        auto a = pop(stack).toStringRef();                               \
        push(stack, string_op);                                          \
      },                                                                 \
      aliasAnalysisFromSchema())
639
//-----------------------------------------------------------------------------
//-----------------------------------------------------------------------------
//-----------------------------------------------------------------------------
//-----------------------------------------------------------------------------
// Unary ".complex" overload; `op` is an expression over local `a`.
#define DEFINE_UNARY_COMPLEX_OP(aten_op, op, result)                     \
  OperatorGeneratorArgs(                                                 \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".complex(complex a) -> " #result), \
      [](Stack& stack) {                                                 \
        c10::complex<double> a;                                          \
        pop(stack, a);                                                   \
        push(stack, op);                                                 \
      },                                                                 \
      aliasAnalysisFromSchema())

// Some complex unary ops (like abs, angle) return real valued output, but most
// other unary ops return complex valued output. So, this macro is used in the
// former case where we can explicitly pass complex_result_cast argument, which
// is set to c10::complex<float> in the macro `DEFINE_UNARY_OP_WITH_COMPLEX`
// defined below.
#define DEFINE_UNARY_OP_WITH_COMPLEX_CAST(                               \
    aten_op,                                                             \
    op,                                                                  \
    int_result,                                                          \
    float_result,                                                        \
    complex_result,                                                      \
    complex_result_cast)                                                 \
  DEFINE_UNARY_INT_OP(aten_op, op, int_result),                          \
      DEFINE_UNARY_FLOAT_OP(aten_op, op, float_result),                  \
      DEFINE_UNARY_COMPLEX_OP(aten_op, op, complex_result),              \
      OperatorGeneratorArgs(                                             \
          TORCH_SELECTIVE_SCHEMA(#aten_op ".Scalar(Scalar a) -> Scalar"), \
          [](Stack& stack) {                                             \
            IValue x;                                                    \
            pop(stack, x);                                               \
            if (x.isDouble()) {                                          \
              double a = x.toDouble();                                   \
              push(stack, static_cast<float_result>(op));                \
            } else if (x.isComplexDouble()) {                            \
              c10::complex<double> a = x.toComplexDouble();              \
              push(stack, static_cast<complex_result_cast>(op));         \
            } else {                                                     \
              int64_t a = x.toInt();                                     \
              push(stack, static_cast<int_result>(op));                  \
            }                                                            \
          },                                                             \
          aliasAnalysisFromSchema())

// Unary op whose complex overload also yields a complex result.
#define DEFINE_UNARY_OP_WITH_COMPLEX(aten_op, op, int_result, float_result) \
  DEFINE_UNARY_OP_WITH_COMPLEX_CAST(                                     \
      aten_op, op, int_result, float_result, complex, c10::complex<double>)
690
// ".int", ".complex" and ".float" overloads of a binary op, each with its
// own op expression over locals `a` and `b`.
#define DEFINE_GENERIC_OP_WITH_COMPLEX(                                  \
    aten_op,                                                             \
    int_op,                                                              \
    float_op,                                                            \
    complex_op,                                                          \
    int_result,                                                          \
    float_result,                                                        \
    complex_result)                                                      \
  OperatorGeneratorArgs(                                                 \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a, int b) -> " #int_result), \
      [](Stack& stack) {                                                 \
        int64_t a, b;                                                    \
        pop(stack, a, b);                                                \
        push(stack, int_op);                                             \
      },                                                                 \
      aliasAnalysisFromSchema()),                                        \
      OperatorGeneratorArgs(                                             \
          TORCH_SELECTIVE_SCHEMA(                                        \
              #aten_op ".complex(complex a, complex b) -> " #complex_result), \
          [](Stack& stack) {                                             \
            c10::complex<double> a, b;                                   \
            pop(stack, a, b);                                            \
            push(stack, complex_op);                                     \
          },                                                             \
          aliasAnalysisFromSchema()),                                    \
      OperatorGeneratorArgs(                                             \
          TORCH_SELECTIVE_SCHEMA(                                        \
              #aten_op ".float(float a, float b) -> " #float_result),    \
          [](Stack& stack) {                                             \
            double a, b;                                                 \
            pop(stack, a, b);                                            \
            push(stack, float_op);                                       \
          },                                                             \
          aliasAnalysisFromSchema())

// Mixed ".int_complex" and ".complex_int" overloads.
#define DEFINE_INT_COMPLEX_OP(aten_op, op, result)                       \
  OperatorGeneratorArgs(                                                 \
      TORCH_SELECTIVE_SCHEMA(#aten_op                                    \
                             ".int_complex(int a, complex b) -> " #result), \
      [](Stack& stack) {                                                 \
        int64_t a;                                                       \
        c10::complex<double> b;                                          \
        pop(stack, a, b);                                                \
        push(stack, op);                                                 \
      },                                                                 \
      aliasAnalysisFromSchema()),                                        \
      OperatorGeneratorArgs(                                             \
          TORCH_SELECTIVE_SCHEMA(                                        \
              #aten_op ".complex_int(complex a, int b) -> " #result),    \
          [](Stack& stack) {                                             \
            c10::complex<double> a;                                      \
            int64_t b;                                                   \
            pop(stack, a, b);                                            \
            push(stack, op);                                             \
          },                                                             \
          aliasAnalysisFromSchema())

// Mixed ".float_complex" and ".complex_float" overloads.
#define DEFINE_FLOAT_COMPLEX_OP(aten_op, op, result)                     \
  OperatorGeneratorArgs(                                                 \
      TORCH_SELECTIVE_SCHEMA(                                            \
          #aten_op ".float_complex(float a, complex b) -> " #result),    \
      [](Stack& stack) {                                                 \
        double a;                                                        \
        c10::complex<double> b;                                          \
        pop(stack, a, b);                                                \
        push(stack, op);                                                 \
      },                                                                 \
      aliasAnalysisFromSchema()),                                        \
      OperatorGeneratorArgs(                                             \
          TORCH_SELECTIVE_SCHEMA(                                        \
              #aten_op ".complex_float(complex a, float b) -> " #result), \
          [](Stack& stack) {                                             \
            c10::complex<double> a;                                      \
            double b;                                                    \
            pop(stack, a, b);                                            \
            push(stack, op);                                             \
          },                                                             \
          aliasAnalysisFromSchema())
769
// Scalar-Scalar overload dispatching on the dynamic complex/double/int tags
// of both operands; any complex operand routes to complex_op.
#define DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_AVOID_COLLISION_GENERIC(    \
    aten_op, int_op, float_op, complex_op, result, string_val)           \
  OperatorGeneratorArgs(                                                 \
      TORCH_SELECTIVE_SCHEMA(#aten_op string_val                         \
                             "(Scalar a, Scalar b) -> " #result),        \
      [](Stack& stack) {                                                 \
        IValue x, y;                                                     \
        pop(stack, x, y);                                                \
        if (x.isComplexDouble()) {                                       \
          c10::complex<double> a = x.toComplexDouble();                  \
          if (y.isComplexDouble()) {                                     \
            c10::complex<double> b = y.toComplexDouble();                \
            push(stack, complex_op);                                     \
          } else if (y.isDouble()) {                                     \
            double b = y.toDouble();                                     \
            push(stack, complex_op);                                     \
          } else {                                                       \
            int64_t b = y.toInt();                                       \
            push(stack, complex_op);                                     \
          }                                                              \
        } else if (x.isDouble()) {                                       \
          double a = x.toDouble();                                       \
          if (y.isComplexDouble()) {                                     \
            c10::complex<double> b = y.toComplexDouble();                \
            push(stack, complex_op);                                     \
          } else if (y.isDouble()) {                                     \
            double b = y.toDouble();                                     \
            push(stack, float_op);                                       \
          } else {                                                       \
            int64_t b = y.toInt();                                       \
            push(stack, float_op);                                       \
          }                                                              \
        } else {                                                         \
          int64_t a = x.toInt();                                         \
          if (y.isComplexDouble()) {                                     \
            c10::complex<double> b = y.toComplexDouble();                \
            push(stack, complex_op);                                     \
          } else if (y.isDouble()) {                                     \
            double b = y.toDouble();                                     \
            push(stack, float_op);                                       \
          } else {                                                       \
            int64_t b = y.toInt();                                       \
            push(stack, int_op);                                         \
          }                                                              \
        }                                                                \
      },                                                                 \
      aliasAnalysisFromSchema())

// Like the macro above, but the int/complex operand pairings are
// intentionally not handled: in those branches no result is pushed.
#define DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_WITHOUT_INT_COMPLEX_PAIR(   \
    aten_op, int_op, float_op, complex_op, result)                       \
  OperatorGeneratorArgs(                                                 \
      TORCH_SELECTIVE_SCHEMA(#aten_op "(Scalar a, Scalar b) -> " #result), \
      [](Stack& stack) {                                                 \
        IValue x, y;                                                     \
        pop(stack, x, y);                                                \
        if (x.isComplexDouble()) {                                       \
          c10::complex<double> a = x.toComplexDouble();                  \
          if (y.isComplexDouble()) {                                     \
            c10::complex<double> b = y.toComplexDouble();                \
            push(stack, complex_op);                                     \
          } else if (y.isDouble()) {                                     \
            double b = y.toDouble();                                     \
            push(stack, complex_op);                                     \
          }                                                              \
        } else if (x.isDouble()) {                                       \
          double a = x.toDouble();                                       \
          if (y.isComplexDouble()) {                                     \
            c10::complex<double> b = y.toComplexDouble();                \
            push(stack, complex_op);                                     \
          } else if (y.isDouble()) {                                     \
            double b = y.toDouble();                                     \
            push(stack, float_op);                                       \
          } else {                                                       \
            int64_t b = y.toInt();                                       \
            push(stack, float_op);                                       \
          }                                                              \
        } else {                                                         \
          int64_t a = x.toInt();                                         \
          if (y.isDouble()) {                                            \
            double b = y.toDouble();                                     \
            push(stack, float_op);                                       \
          } else if (y.isInt()) {                                        \
            int64_t b = y.toInt();                                       \
            push(stack, int_op);                                         \
          }                                                              \
        }                                                                \
      },                                                                 \
      aliasAnalysisFromSchema())

// Plain Scalar-Scalar overload (no overload-name suffix).
#define DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX(                            \
    aten_op, int_op, float_op, complex_op, result)                       \
  DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_AVOID_COLLISION_GENERIC(          \
      aten_op, int_op, float_op, complex_op, result, "")

// Full overload set for an arithmetic binary op that supports complex.
#define DEFINE_BINARY_OP_WITH_COMPLEX(aten_op, op)                       \
  DEFINE_GENERIC_OP_WITH_COMPLEX(aten_op, op, op, op, int, float, complex), \
      DEFINE_INT_COMPLEX_OP(aten_op, op, complex),                       \
      DEFINE_FLOAT_COMPLEX_OP(aten_op, op, complex),                     \
      DEFINE_INT_FLOAT_OP(aten_op, op, float),                           \
      DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX(aten_op, op, op, op, Scalar)

// Full overload set for a comparison op that supports complex operands.
#define DEFINE_COMPARISON_OP_WITH_COMPLEX(aten_op, op)                   \
  DEFINE_GENERIC_OP_WITH_COMPLEX(aten_op, op, op, op, bool, bool, bool), \
      DEFINE_INT_FLOAT_OP(aten_op, op, bool),                            \
      DEFINE_FLOAT_COMPLEX_OP(aten_op, op, bool),                        \
      DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_WITHOUT_INT_COMPLEX_PAIR(     \
          aten_op, op, op, op, bool),                                    \
      DEFINE_STR_CMP_OP(aten_op, op)
878
// Create an RNG generator for `device`, seeded with `seed` when provided
// (defined in the .cpp).
TORCH_API at::Generator make_generator_for_device(
    c10::Device device,
    std::optional<int64_t> seed = std::nullopt);
882
883 } // namespace torch::jit
884