#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/ops.h>

#include <ATen/CPUFunctions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/NonSymbolicBC.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/irange.h>
#include <c10/util/ssize.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>

namespace {
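// Note: a borrowed IValue refers to a value owned elsewhere without bumping
// its refcount; the memory planner later releases such outputs via the
// corresponding destroyBorrow method. See [Borrowed IValue Outputs].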
constexpr auto createBorrowedIValue =
    c10::MaybeOwnedTraits<c10::IValue>::createBorrow;
} // namespace
namespace torch::jit {

namespace {

std::vector<IValue> boxInputs(const ProcessedNode& pnode) {
  std::vector<IValue> result;
  for (const auto i : c10::irange(pnode.num_inputs())) {
    result.push_back(pnode.Input(i));
  }
  return result;
}

} // namespace

C10_DEFINE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor);

bool nativeOpIsRegistered(const c10::Symbol& op_name) {
  const std::string name(op_name.toQualString());
  return SRNativeOperatorRegistry()->Has(name);
}

SROperator getNativeOperation(Node* n) {
  auto op_name = n->kind().toQualString();
  if (SRNativeOperatorRegistry()->Has(op_name)) {
    return SRNativeOperatorRegistry()->Create(op_name)->Generate(n);
  }
  return nullptr;
}

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::TupleConstruct,
    prim_TupleConstruct,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::TupleConstruct)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        // prepare inputs
        auto stack = boxInputs(*p_node);
        // run op
        auto* node = p_node->node();
        const auto& type = node->output()->type()->expect<TupleType>();
        if (type->name().has_value()) {
          namedTupleConstruct(stack, type, node->inputs().size());
        } else {
          tupleConstruct(stack, node->inputs().size());
        }
        // put output back
        p_node->Output(0) = std::move(stack[0]);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::TupleUnpack,
    prim_TupleUnpack,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::TupleUnpack)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& elems = p_node->Input(0).toTupleRef().elements();
        const size_t num_outputs = p_node->outputs().size();
        TORCH_CHECK(
            num_outputs == elems.size(),
            "Number of outputs must match number of tuple elements.")
        for (size_t i = 0; i < num_outputs; ++i) {
          p_node->Output(i) = elems[i];
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::DictConstruct,
    prim_DictConstruct,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::DictConstruct)) {
        return nullptr;
      }
      auto dict_type = n->output()->type()->expect<DictType>();
      const auto num_inputs = n->inputs().size();
      TORCH_DCHECK_EQ(num_inputs % 2, 0);
      return [dict_type = std::move(dict_type),
              num_inputs,
              dict_size = num_inputs / 2](ProcessedNode* p_node) {
        auto result = c10::impl::GenericDict(
            dict_type->containedType(0), dict_type->containedType(1));
        result.reserve(dict_size);
        for (size_t i = 0; i < num_inputs; i += 2) {
          const auto& key = p_node->Input(i);
          const auto& value = p_node->Input(i + 1);
          result.insert_or_assign(key, value);
        }
        p_node->Output(0) = result;
      };
    });

// See [Borrowed IValue Outputs]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    static_runtime::dict_unpack,
    static_runtime_dict_unpack,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "static_runtime::dict_unpack(...) -> ...")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        DCHECK(
            static_cast<size_t>(p_node->num_inputs() - 1) ==
            p_node->outputs().size());
        auto dict = p_node->Input(0).toGenericDict();
        const auto num_inputs = p_node->num_inputs();
        for (size_t i = 1; i < num_inputs; ++i) {
          const auto& key = p_node->Input(i);
          auto value = dict.find(key);
          TORCH_CHECK(value != dict.end(), "Key not in dict: ", key);
          p_node->Output(i - 1) = createBorrowedIValue(value->value());
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::__getitem__, aten_getitem, [](Node* n) -> SROperator {
  if (!sr_schema_check(
          n,
          // TODO: "aten::__getitem__.str(str s, int index) -> str",
          "aten::__getitem__.t(t[](a) list, int idx) -> t(*)",
          "aten::__getitem__.Dict_str(Dict(str, t) self, str key) -> t(*)",
          "aten::__getitem__.Dict_int(Dict(int, t) self, int key) -> t(*)",
          "aten::__getitem__.Dict_bool(Dict(bool, t) self, bool key) -> t(*)",
          "aten::__getitem__.Dict_float(Dict(float, t) self, float key) -> t(*)",
          "aten::__getitem__.Dict_complex(Dict(complex, t) self, complex key) -> t(*)",
          "aten::__getitem__.Dict_Tensor(Dict(Tensor, t) self, Tensor key) -> t(*)")) {
    return nullptr;
  }

  if (n->inputs().size() != 2) {
    return nullptr;
  }

  if (n->input(0)->type()->castRaw<DictType>()) {
    return [](ProcessedNode* p_node) {
      auto dict = p_node->Input(0).toGenericDict();
      const auto& key = p_node->Input(1);
      auto value = dict.find(key);
      TORCH_CHECK(value != dict.end(), "Key not in dict: ", key);
      p_node->Output(0) = value->value();
    };
  } else if (n->input(0)->type()->castRaw<ListType>()) {
    return [](ProcessedNode* p_node) {
      const auto& list = p_node->Input(0).toList();
      auto idx = p_node->Input(1).toInt();
      p_node->Output(0) = getItem(list, idx);
    };
  }

  // TODO(T98581096): make __getitem__ work for other container types
  return nullptr;
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::ListConstruct,
    prim_ListConstruct,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::ListConstruct)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        // prepare inputs
        auto stack = boxInputs(*p_node);
        // run op
        listConstruct(
            stack,
            p_node->node()->output()->type()->expectRef<ListType>(),
            p_node->num_inputs());
        // put output back
        p_node->Output(0) = std::move(stack[0]);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::ListUnpack,
    prim_ListUnpack,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::ListUnpack)) {
        return nullptr;
      }
      const auto num_outputs = n->outputs().size();
      return [num_outputs](ProcessedNode* p_node) {
        const auto list = p_node->Input(0).toListRef();
        TORCH_CHECK(
            list.size() == num_outputs,
            "Expected ",
            num_outputs,
            " elements in list but got ",
            list.size());
        for (const auto i : c10::irange(num_outputs)) {
          p_node->Output(i) = list[i];
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::append,
    aten_append,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(
              n, "aten::append.t(t[](a!) self, t(c -> *) el) -> t[](a!)")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto list = p_node->Input(0).toList();
        list.push_back(p_node->Input(1));
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::list,
    aten_list,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema("aten::list(str t) -> str[]"))) {
        return [](ProcessedNode* p_node) {
          const auto str = p_node->Input(0).toStringRef();
          c10::List<std::string> chars;
          chars.reserve(str.size());
          for (auto c : str) {
            chars.emplace_back(1, c);
          }
          p_node->Output(0) = std::move(chars);
        };
      }

      if (n->matches(torch::schema("aten::list.t(t[] l) -> t[]"))) {
        return [](ProcessedNode* p_node) {
          const auto input = p_node->Input(0).toList();
          p_node->Output(0) = input.copy();
        };
      }

      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::numel,
    aten_numel,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::numel(Tensor self) -> int")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& arg = p_node->Input(0).toTensor();
        p_node->Output(0) = arg.numel();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::cpu,
    aten_cpu,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::cpu(Tensor self) -> Tensor")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& arg = p_node->Input(0).toTensor();
        p_node->Output(0) = arg.cpu();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::__range_length,
    aten_range_length,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(
              n, "aten::__range_length(int lo, int hi, int step) -> int")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto lo = p_node->Input(0).toInt();
        auto hi = p_node->Input(1).toInt();
        auto step = p_node->Input(2).toInt();
        // error handling when step_val == 0 during runtime
        if (step == 0) {
          throw std::runtime_error("range() arg 3 must not be zero");
        }
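        // The branches below compute ceil((hi - lo) / step) in integer
        // arithmetic (mirrored for negative step). For example, lo=0, hi=10,
        // step=3 gives 1 + (10 - 1 - 0) / 3 = 4, matching the four elements
        // 0, 3, 6, 9 of range(0, 10, 3).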
        if (step > 0 && lo < hi) {
          p_node->Output(0) = 1 + (hi - 1 - lo) / step;
        } else if (step < 0 && lo > hi) {
          p_node->Output(0) = 1 + (lo - 1 - hi) / (0 - step);
        } else {
          p_node->Output(0) = 0;
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::index_put, aten_index_put, [](Node* n) -> SROperator {
  if (n->matches(torch::schema(
          "aten::index_put(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor")) ||
      n->matches(torch::schema(
          "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"))) {
    return [](ProcessedNode* p_node) {
      const auto& self = p_node->Input(0).toTensor();
      const auto& indices =
          at::native::toListOfOptionalTensors(p_node->Input(1).toListRef());
      const auto& values = p_node->Input(2).toTensor();
      const auto accumulate = p_node->Input(3).toBool();
      p_node->Output(0) =
          at::native::index_put(self, indices, values, accumulate);
    };
  }

  LogAndDumpSchema(n);
  return nullptr;
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::item,
    aten_item,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::item(Tensor self) -> Scalar")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor();
        p_node->Output(0) = at::native::item(self);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::GetAttr,
    prim_GetAttr,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::GetAttr)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto& module = p_node->Input(0).toObjectRef();
        Node* node = p_node->node();
        const auto& type = node->input()->type()->expectRef<ClassType>();
        const auto& field = node->s(attr::name);
        const auto slot = type.getAttributeSlot(field);
        p_node->Output(0) = module.getSlot(slot);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::SetAttr,
    prim_SetAttr,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::SetAttr)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto& module = p_node->Input(0).toObjectRef();
        Node* node = p_node->node();
        const auto& type = node->inputs()[0]->type()->expectRef<ClassType>();
        const auto& field = node->s(attr::name);
        const auto slot = type.getAttributeSlot(field);
        module.setSlot(slot, p_node->Input(1));
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::transpose,
    aten_transpose,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        const auto in1_i = p_node->Input(1).toInt();
        const auto in2_i = p_node->Input(2).toInt();
        p_node->Output(0) = at::native::transpose(in0_t, in1_i, in2_i);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::flatten, aten_flatten, [](Node* n) -> SROperator {
  if (!n->matches(torch::schema(
          "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)"))) {
    LogAndDumpSchema(n);
    return nullptr;
  }
  return [](ProcessedNode* p_node) {
    const auto& in0_t = p_node->Input(0).toTensor();
    const auto in1_i = p_node->Input(1).toInt();
    const auto in2_i = p_node->Input(2).toInt();
    p_node->Output(0) = at::native::flatten(in0_t, in1_i, in2_i);
  };
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::permute,
    aten_permute,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        const auto in1_iv = p_node->Input(1).toDimVector();
        p_node->Output(0) = at::native::permute(in0_t, in1_iv);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::reshape,
    aten_reshape,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        const auto in1_iv = p_node->Input(1).toDimVector();
        p_node->Output(0) = at::native::reshape(in0_t, in1_iv);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::slice, aten_slice, [](Node* n) -> SROperator {
  if (!n->matches(torch::schema(
          "aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor(a)"))) {
    LogAndDumpSchema(n);
    return nullptr;
  }
  return [](ProcessedNode* p_node) {
    const auto& in0_t = p_node->Input(0).toTensor();
    const auto in1_i = p_node->Input(1).toInt();
    const auto in2_i = p_node->Input(2).toOptional<int64_t>();
    const auto in3_i = p_node->Input(3).toOptional<int64_t>();
    const auto in4_i = p_node->Input(4).toInt();
    p_node->Output(0) = at::native::slice(in0_t, in1_i, in2_i, in3_i, in4_i);
  };
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::narrow, aten_narrow, [](Node* n) -> SROperator {
  if (!n->matches(torch::schema(
          "aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)")) &&
      !n->matches(torch::schema(
          "aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> Tensor(a)"))) {
    LogAndDumpSchema(n);
    return nullptr;
  }
  return [](ProcessedNode* p_node) {
    const auto& self = p_node->Input(0).toTensor(); // self
    const auto dim = p_node->Input(1).toInt(); // dim
    int64_t start = 0;
    if (p_node->Input(2).isScalar()) {
      start = p_node->Input(2).toInt();
    } else {
      auto& t = p_node->Input(2).toTensor();
      start = t.item<int64_t>();
    }
    const auto length = p_node->Input(3).toInt(); // length
    TORCH_CHECK(
        self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
    auto cur_size = self.sizes()[dim];
    if (start != cur_size && start < 0) { // start being the end is valid, but
                                          // not a valid dim specification.
      start = at::maybe_wrap_dim(start, cur_size);
    }
    TORCH_CHECK(
        length >= 0 && start <= cur_size - length,
        "start (",
        start,
        ") + length (",
        length,
        ") exceeds dimension size (",
        cur_size,
        ").");
    p_node->Output(0) = at::native::slice(self, dim, start, start + length, 1);
  };
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::to, aten_to, [](Node* n) -> SROperator {
  if (n->matches(torch::schema(
          "aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
    return [](ProcessedNode* p_node) {
      const auto& in0_t = p_node->Input(0).toTensor();
      const auto& in1_t = p_node->Input(1).toTensor();
      const auto in2_i = p_node->Input(2).toBool();
      const auto in3_i = p_node->Input(3).toBool();
      const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
      p_node->Output(0) = at::native::to(in0_t, in1_t, in2_i, in3_i, in4_o);
    };
  }
  if (n->matches(torch::schema(
          "aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
    return [](ProcessedNode* p_node) {
      const auto& in0_t = p_node->Input(0).toTensor();
      const auto in1_i = p_node->Input(1).toScalarType();
      const auto in2_i = p_node->Input(2).toBool();
      const auto in3_i = p_node->Input(3).toBool();
      const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
      p_node->Output(0) = at::native::to(in0_t, in1_i, in2_i, in3_i, in4_o);
    };
  }
  if (n->matches(torch::schema(
          "aten::to.prim_dtype(Tensor(a) self, int? dtype, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"))) {
    return [](ProcessedNode* p_node) {
      const auto& in0_t = p_node->Input(0).toTensor();
      const auto in1_i = p_node->Input(1).toOptional<at::ScalarType>();
      const auto in2_i = p_node->Input(2).toBool();
      const auto in3_i = p_node->Input(3).toBool();
      // To mimic the behavior of the JIT interpreter, if both dtype
      // and copy are not set, we return self. Otherwise, we assume
      // that dtype is set.
      if (!in1_i && !in3_i) {
        p_node->Output(0) = in0_t;
      } else {
        TORCH_CHECK(
            in1_i,
            "dtype cannot be None when copy is True for aten::to.prim_dtype");
        p_node->Output(0) = at::native::to(in0_t, *in1_i, in2_i, in3_i);
      }
    };
  }
  LogAndDumpSchema(n);
  return nullptr;
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::detach,
    aten_detach,
    [](Node* n) -> SROperator {
      if (!n->matches(
              torch::schema("aten::detach(Tensor(a) self) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        p_node->Output(0) = at::native::alias(in0_t);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::expand_as,
    aten_expand_as,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor();
        const auto& other = p_node->Input(1).toTensor();
        p_node->Output(0) = self.expand(other.sizes());
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::isinstance,
    prim_isinstance,
    [](Node* n) -> SROperator {
      if (!n->matches(
              torch::schema("prim::isinstance(Any to_check) -> bool"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto input_type = p_node->Input(0).type();

        auto* node = p_node->node();
        const std::vector<TypePtr>& candidates = node->tys(attr::types);
        for (const auto& candidate_type : candidates) {
          if (input_type->isSubtypeOf(*candidate_type)) {
            p_node->Output(0) = true;
            return;
          }
        }

        p_node->Output(0) = false;
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::TypeCheck,
    prim_TypeCheck,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::TypeCheck)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto* node = p_node->node();
        const size_t num_inputs = node->inputs().size();
        TORCH_INTERNAL_ASSERT(
            num_inputs && num_inputs + 1 == node->outputs().size());

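        // The first num_inputs outputs pass the inputs through unchanged; the
        // extra final output is a bool that is true only if every defined
        // input tensor matches its expected TensorType.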
        const auto& expected_types = node->tys(attr::types);

        for (size_t i = 0; i < num_inputs; i++) {
          p_node->Output(i) = p_node->Input(i);
        }

        for (size_t i = 0; i < num_inputs; i++) {
          auto& input_tensor = p_node->Input(i).toTensor();
          auto* expected_type = expected_types[i]->castRaw<TensorType>();
          if (input_tensor.defined() &&
              !expected_type->matchTensor(input_tensor)) {
            p_node->Output(num_inputs) = false;
            return;
          }
        }

        p_node->Output(num_inputs) = true;
      };
    });

// See [Borrowed IValue Outputs]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    static_runtime::VarTupleUnpack,
    static_runtime_VarTupleUnpack,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "static_runtime::VarTupleUnpack(...) -> ...")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        size_t output_idx = 0;
        for (const auto idx : c10::irange(pnode->num_inputs())) {
          const auto& tuple = pnode->Input(idx);
          for (auto& elem : tuple.toTupleRef().elements()) {
            pnode->Output(output_idx) = createBorrowedIValue(elem);
            ++output_idx;
          }
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::view,
    aten_view,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::view(Tensor(a) self, int[] size) -> (Tensor(a))"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& input = p_node->Input(0).toTensor();
        const auto size = p_node->Input(1).toIntList();
        p_node->Output(0) = at::native::view(input, size.vec());
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::size,
    aten_size,
    [](Node* n) -> SROperator {
      if (n->matches(
              torch::schema("aten::size(Tensor self, int dim) -> int"))) {
        return [](ProcessedNode* p_node) {
          const auto& input = p_node->Input(0).toTensor();
          auto dim = p_node->Input(1).toInt();
          const auto ndim = input.dim();

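          // A negative dim counts from the end; maybe_wrap_dim normalizes it
          // (and rejects values that are still out of range).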
          if (dim < 0 || dim >= ndim) {
            dim = c10::maybe_wrap_dim(dim, ndim);
          }
          p_node->Output(0) = input.sizes()[dim];
        };
      }
      if (n->matches(torch::schema("aten::size(Tensor self) -> int[]"))) {
        return [](ProcessedNode* p_node) {
          const auto& input = p_node->Input(0).toTensor();
          p_node->Output(0) = input.sizes();
        };
      }
      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::squeeze,
    aten_squeeze,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }

      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor();
        const auto dim = p_node->Input(1).toInt();
        p_node->Output(0) = at::native::squeeze(self, dim);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::split, aten_split, [](Node* n) -> SROperator {
  if (n->matches(torch::schema(
          "aten::split(Tensor(a -> *) self, int split_size, int dim=0) -> Tensor(a)[]"))) {
    return [](ProcessedNode* p_node) {
      const auto& self = p_node->Input(0).toTensor();
      const auto split_size = p_node->Input(1).toInt();
      const auto dim = p_node->Input(2).toInt();
      p_node->Output(0) = at::native::split(self, split_size, dim);
    };
  }

  if (n->matches(torch::schema(
          "aten::split(Tensor(a -> *) self, int[] split_sizes, int dim=0) -> (Tensor[])"))) {
    return [](ProcessedNode* p_node) {
      const auto& self = p_node->Input(0).toTensor();
      const auto& split_sizes = p_node->Input(1).toIntList();
      const auto dim = p_node->Input(2).toInt();
      p_node->Output(0) =
          at::native::split_with_sizes(self, split_sizes.vec(), dim);
    };
  }

  LogAndDumpSchema(n);
  return nullptr;
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::split_with_sizes,
    aten_split_with_sizes,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::split_with_sizes(Tensor(a -> *) self, int[] split_sizes, int dim=0) -> Tensor(a)[]")) &&
          !n->matches(torch::schema(
              "aten::split_with_sizes(Tensor(a -> *) self, int[] split_sizes, int dim=0) -> (Tensor[])"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor();
        const auto& split_sizes = p_node->Input(1).toIntList();
        const auto dim = p_node->Input(2).toInt();
        p_node->Output(0) =
            at::native::split_with_sizes(self, split_sizes.vec(), dim);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    static_runtime::select_tensor,
    aten_select_tensor,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(
              n,
              "static_runtime::select_tensor(Tensor(a) a, Tensor(b) b, bool use_b) -> Tensor(a|b)")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto did_copy = p_node->Input(2).toBool();
        DCHECK(p_node->Input(0).isTensor());
        DCHECK(!did_copy || p_node->Input(1).isTensor());
        const IValue& assignFrom =
            did_copy ? p_node->Input(1) : p_node->Input(0);
        // Create an IValue that borrows the input Tensor in order to
        // save a refcount increment here and decrement in
        // MemoryPlanner::deallocate. MemoryPlanner knows about this
        // and will safely clean it up by using the corresponding
        // destroyBorrow method.
        TORCH_DCHECK_NE(&assignFrom, &p_node->Output(0));
        // MemoryPlanner should have cleaned this up!
        DCHECK(p_node->Output(0).isNone());
        p_node->Output(0) =
            IValue(c10::MaybeOwnedTraits<at::TensorBase>::createBorrow(
                assignFrom.toTensor()));
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::mul,
    aten_mul,
    [](Node* n) -> SROperator {
      if (!n->matches(
              torch::schema("aten::mul.left_t(t[] l, int n) -> (t[])"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& list = pnode->Input(0).toList();
        const auto n = pnode->Input(1).toInt();

        auto list_type = list.elementType();
        auto ret = c10::impl::GenericList(list_type);
        ret.reserve(list.size() * n);
        for (const auto i : c10::irange(n)) {
          (void)i;
          for (const auto& ival : list) {
            ret.push_back(ival);
          }
        }
        pnode->Output(0) = ret;
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::sub,
    aten_sub,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema("aten::sub.int(int a, int b) -> (int)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto a = pnode->Input(0).toInt();
        const auto b = pnode->Input(1).toInt();
        pnode->Output(0) = a - b;
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::add,
    aten_add,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema("aten::add.t(t[] a, t[] b) -> (t[])"))) {
        return [](ProcessedNode* pnode) {
          const auto& a = pnode->Input(0).toList();
          const auto& b = pnode->Input(1).toList();
          auto ret = a.copy();
          ret.append(b);
          pnode->Output(0) = ret;
        };
      }

      if (n->matches(torch::schema("aten::add.int(int a, int b) -> (int)"))) {
        return [](ProcessedNode* pnode) {
          const auto a = pnode->Input(0).toInt();
          const auto b = pnode->Input(1).toInt();
          pnode->Output(0) = a + b;
        };
      }

      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::tensor_split, aten_tensor_split, [](Node* n) -> SROperator {
  if (n->matches(torch::schema(
          "aten::tensor_split.indices(Tensor(a -> *) self, int[] indices, int dim=0) -> Tensor(a)[]"))) {
    return [](ProcessedNode* pnode) {
      const auto& a = pnode->Input(0).toTensor();
      const auto& b = pnode->Input(1).toIntVector();
      const auto c = pnode->Input(2).toInt();
      pnode->Output(0) = at::native::tensor_split(a, b, c);
    };
  }

  if (n->matches(torch::schema(
          "aten::tensor_split.sections(Tensor(a -> *) self, int sections, int dim=0) -> Tensor(a)[]"))) {
    return [](ProcessedNode* pnode) {
      const auto& a = pnode->Input(0).toTensor();
      const auto b = pnode->Input(1).toSymInt();
      const auto c = pnode->Input(2).toInt();
      pnode->Output(0) = at::native::tensor_split_sections_symint(a, b, c);
    };
  }

  if (n->matches(torch::schema(
          "aten::tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[]"))) {
    return [](ProcessedNode* pnode) {
      const auto& a = pnode->Input(0).toTensor();
      const auto& b = pnode->Input(1).toTensor();
      const auto c = pnode->Input(2).toInt();
      pnode->Output(0) = at::native::tensor_split(a, b, c);
    };
  }
  LogAndDumpSchema(n);
  return nullptr;
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::Int,
    aten_Int,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema("aten::Int(Tensor a) -> int"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = at::native::item(input).toInt();
      };
    });

// See [Create owned refs for special values]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    static_runtime::create_owned_ref,
    static_runtime_create_owned_ref,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "static_runtime::create_owned_ref(...) -> ...")) {
        return nullptr;
      }
      return
          [](ProcessedNode* p_node) { p_node->Output(0) = p_node->Input(0); };
    });

namespace {
bool outputsEmpty(const Block* block) {
  return block->outputs().size() == 1 && block->outputs().at(0)->mustBeNone();
}

bool blockEmpty(const Block* block) {
  return block->nodes().begin() == block->nodes().end();
}

enum class BlockRunPlan : int8_t {
  kRunOnlyTrueBlock,
  kRunOnlyFalseBlock,
  kRunBothBlocks,
  kRunNeitherBlock,
};
} // namespace

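// prim::If dispatches to one of the node's two sub-blocks via the block
// runners stored in the ProcessedNode's metadata. When both blocks return
// None, a block with no nodes never needs to run, so the BlockRunPlan
// computed below lets such branches be skipped entirely at runtime.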
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::If,
    prim_If,
    [](Node* node) -> SROperator {
      if (!sr_schema_check_kind(node, prim::If)) {
        return nullptr;
      }
      TORCH_DCHECK_EQ(node->blocks().size(), 2);
      const Block* true_block = node->blocks().at(0);
      const Block* false_block = node->blocks().at(1);

      const bool true_block_returns_empty = outputsEmpty(true_block);
      const bool false_block_returns_empty = outputsEmpty(false_block);

      BlockRunPlan block_run_plan = BlockRunPlan::kRunNeitherBlock;

      if (true_block_returns_empty && false_block_returns_empty) {
        const bool false_block_is_empty = blockEmpty(false_block);
        const bool true_block_is_empty = blockEmpty(true_block);

        if (false_block_is_empty && !true_block_is_empty) {
          block_run_plan = BlockRunPlan::kRunOnlyTrueBlock;
        } else if (!false_block_is_empty && true_block_is_empty) {
          block_run_plan = BlockRunPlan::kRunOnlyFalseBlock;
        } else if (false_block_is_empty && true_block_is_empty) {
          block_run_plan = BlockRunPlan::kRunNeitherBlock;
        } else {
          block_run_plan = BlockRunPlan::kRunBothBlocks;
        }
      } else {
        block_run_plan = BlockRunPlan::kRunBothBlocks;
      }

      switch (block_run_plan) {
        case BlockRunPlan::kRunBothBlocks:
          return [](ProcessedNode* p_node) {
            auto condition = p_node->Input(0).toBool();
            auto* metadata = p_node->metadata();
            DCHECK(metadata);
            auto& block_runners = metadata->block_runners();
            TORCH_DCHECK_EQ(block_runners.size(), 2);
            auto& runner = block_runners[!condition];

            auto output = runner({});
            // If we are returning a tuple, we are either returning
            // multiple unpacked values or all of the values wrapped
            // in a single tuple. The second condition handles the
            // latter case.
            if (!output.isTuple() || p_node->num_outputs() == 1) {
              p_node->Output(0) = std::move(output);
              return;
            }
            auto& elems = output.toTupleRef().elements();
            TORCH_DCHECK_EQ(elems.size(), p_node->num_outputs());
            for (const auto i : c10::irange(elems.size())) {
              p_node->Output(i) = elems[i];
            }
          };
        case BlockRunPlan::kRunOnlyTrueBlock:
          return [](ProcessedNode* p_node) {
            auto condition = p_node->Input(0).toBool();
            auto* metadata = p_node->metadata();
            DCHECK(metadata);
            auto& block_runners = metadata->block_runners();
            TORCH_DCHECK_EQ(block_runners.size(), 2);
            if (condition) {
              auto output = block_runners.front()({});
              DCHECK(output.isNone());
            }
          };
        case BlockRunPlan::kRunOnlyFalseBlock:
          return [](ProcessedNode* p_node) {
            auto condition = p_node->Input(0).toBool();
            auto* metadata = p_node->metadata();
            DCHECK(metadata);
            auto& block_runners = metadata->block_runners();
            TORCH_DCHECK_EQ(block_runners.size(), 2);
            if (!condition) {
              auto output = block_runners.back()({});
              DCHECK(output.isNone());
            }
          };
        case BlockRunPlan::kRunNeitherBlock:
          return [](ProcessedNode*) {};
      }
      return [](ProcessedNode*) {};
    });

namespace {

std::vector<IValue> collectLoopSubBlockInputs(const ProcessedNode& p_node) {
  const auto num_inputs = p_node.num_inputs();
  TORCH_DCHECK_GE(num_inputs, 2);
  // The first two inputs to the loop node are the max trip count
  // and initial condition. We don't collect them here, since those
  // are not inputs for the sub-block.
  const auto num_args = num_inputs - 2;

  std::vector<IValue> result;
  result.reserve(num_args + 1);
  // First argument to the loop sub-block is always the loop counter, initially
  // zero.
  result.emplace_back(0);

  for (const auto i : c10::irange(num_args)) {
    result.push_back(p_node.Input(2 + i));
  }

  return result;
}

} // namespace

namespace {
/*
  ForkedSubgraphSRLauncher is responsible for executing a forked
  subgraph on a new instance of the static runtime. Once execution
  completes, the future is marked complete so that the corresponding
  aten::wait() can proceed.
*/
class TORCH_API ForkedSubgraphSRLauncher {
 public:
  ForkedSubgraphSRLauncher(
      std::shared_ptr<StaticModule> smodule,
      std::vector<IValue> args,
      c10::intrusive_ptr<Future> future,
      TaskLauncher launcher)
      : smodule_(std::move(smodule)),
        args_(std::move(args)),
        future_(std::move(future)),
        launcher_(std::move(launcher)) {}

  void operator()() {
    try {
      StaticRuntime runtime(*smodule_);
      auto future_subgraph = runtime.runAsync(args_, {}, launcher_);
      future_subgraph->waitAndThrow();
      future_->markCompleted(future_subgraph->value());
    } catch (const std::exception& e) {
      future_->setErrorIfNeeded(
          std::make_exception_ptr(c10::ivalue::Future::FutureError(e.what())));
    }
  }

 private:
  std::shared_ptr<StaticModule> smodule_;
  std::vector<IValue> args_;
  c10::intrusive_ptr<Future> future_;
  torch::jit::TaskLauncher launcher_;
};

/*
  Helper function that creates a future whose type matches the
  graph's output type. It is used by the prim::fork and aten::wait
  operations for async execution of subgraphs.
*/
c10::intrusive_ptr<Future> createFutureTypeFromGraphOutput(
    const std::shared_ptr<torch::jit::Graph>& graph) {
  TypePtr return_type_;
  if (graph->outputs().size() == 1) {
    return_type_ = graph->outputs().at(0)->type();
  } else {
    return_type_ = TupleType::create(
        fmap(graph->outputs(), [](const Value* v) { return v->type(); }));
  }
  c10::intrusive_ptr<Future> future = c10::make_intrusive<Future>(return_type_);
  return future;
}
} // namespace

/*
  prim::fork forks the execution of a subgraph. It returns a future on which
  the corresponding aten::wait op waits until the future is marked complete.
  The current implementation creates an instance of StaticModule and uses it
  to create StaticRuntime instances on the fly at runtime to handle the
  execution of the forked subgraph. Async execution is handled by the
  aten::ParallelThreadPoolNative threadpool.
*/
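// A minimal sketch of the TorchScript-level pattern these fork/wait ops
// implement (illustrative only; foo/bar are hypothetical user functions):
//
//   def foo(a, b):
//       return a + b
//
//   @torch.jit.script
//   def bar(x, y):
//       fut = torch.jit.fork(foo, x, y)  # becomes a prim::fork node
//       # ... other work ...
//       return torch.jit.wait(fut)       # becomes an aten::wait node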
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::fork,
    prim_Fork,
    [](Node* node) -> SROperator {
      if (!sr_schema_check_kind(node, prim::fork)) {
        return nullptr;
      }
      auto forkedGraph = node->g(attr::Subgraph);
      Inline(*forkedGraph);
      auto sr_metadata = node->ival(getStaticRuntimeMetadataSymbol())
                             .toCustomClass<StaticRuntimeMetadata>();
      auto smodule =
          std::make_shared<StaticModule>(forkedGraph, sr_metadata->get_opts());

      return [forkedGraph = std::move(forkedGraph),
              smodule = std::move(smodule)](ProcessedNode* p_node) {
        std::vector<IValue> args;
        args.reserve(p_node->num_inputs());
        for (const auto i : c10::irange(p_node->num_inputs())) {
          args.push_back(p_node->Input(i));
        }

        c10::intrusive_ptr<Future> future =
            createFutureTypeFromGraphOutput(forkedGraph);
        p_node->Output(0) = future;

        auto* metadata = p_node->metadata();
        DCHECK(metadata);
        auto* launcher = metadata->launcher();
        DCHECK(launcher);
        ForkedSubgraphSRLauncher runtime_launcher(
            smodule, args, future, *launcher);
        (*launcher)(std::move(runtime_launcher));
      };
    });
/*
  aten::wait blocks on the future produced by the corresponding fork.
  Once execution of the forked subgraph is complete, the future is marked
  completed and execution continues past the wait.
*/
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::wait,
    aten_Wait,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::wait(Future(t) self) -> t")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        TORCH_INTERNAL_ASSERT(p_node->Input(0).isFuture());
        auto future = p_node->Input(0).toFuture();

        // blocking call: waiting for the future to be completed
        future->waitAndThrow();

        TORCH_INTERNAL_ASSERT(future->completed());
        TORCH_INTERNAL_ASSERT(!future->hasError());
        TORCH_INTERNAL_ASSERT(future->hasValue());

        if (!future->value().isTuple()) {
          p_node->Output(0) = future->value();
          return;
        }
        auto& elems = future->value().toTupleRef().elements();
        TORCH_DCHECK_EQ(elems.size(), p_node->num_outputs());
        for (const auto i : c10::irange(elems.size())) {
          p_node->Output(i) = elems[i];
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::Loop,
    prim_Loop,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::Loop)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto max_trip_count = p_node->Input(0).toInt();
        auto condition = p_node->Input(1).toBool();

        auto* metadata = p_node->metadata();
        DCHECK(metadata);
        auto& block_runners = metadata->block_runners();
        TORCH_DCHECK_EQ(block_runners.size(), 1);
        auto& runner = block_runners[0];

        auto args = collectLoopSubBlockInputs(*p_node);
        int64_t loop_count = 0;

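        // Each iteration, the sub-block returns the continue condition
        // followed by the carried loop variables; the latter are fed back in
        // as args[1:], while args[0] tracks the iteration counter.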
        while (condition && loop_count < max_trip_count) {
          auto output = runner(args);

          if (output.isTuple()) {
            auto& elems = output.toTupleRef().elements();
            DCHECK(elems.size() == args.size());
            for (const auto i : c10::irange(1, args.size())) {
              args[i] = elems[i];
            }
            condition = elems[0].toBool();
          } else {
            condition = output.toBool();
          }
          args[0] = ++loop_count;
        }

        const auto num_outputs = p_node->num_outputs();
        TORCH_DCHECK_EQ(args.size(), num_outputs + 1);
        for (const auto i : c10::irange(num_outputs)) {
          p_node->Output(i) = std::move(args[i + 1]);
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::CreateObject,
    prim_CreateObject,
    [](Node* node) -> SROperator {
      if (!sr_schema_check_kind(node, prim::CreateObject)) {
        return nullptr;
      }
      auto class_type = node->output()->type()->expect<ClassType>();
      return [class_type = std::move(class_type)](ProcessedNode* pnode) {
        pnode->Output(0) = c10::ivalue::Object::create(
            c10::StrongTypePtr(class_type->compilation_unit(), class_type),
            class_type->numAttributes());
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::TupleIndex,
    prim_TupleIndex,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::TupleIndex)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& elems = pnode->Input(0).toTupleRef().elements();
        using c10::ssize;
        const auto num_elems = ssize(elems);
        const auto idx = pnode->Input(1).toInt();
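        // normalizeIndex maps a negative index to Python-style indexing from
        // the end (e.g. -1 refers to the last element); anything still
        // outside [0, num_elems) is rejected below.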
        const auto norm_idx = normalizeIndex(idx, num_elems);
        if (norm_idx < 0 || norm_idx >= num_elems) {
          // Use std::runtime_error instead of c10::Error to be consistent with
          // JIT
          throw std::out_of_range("Tuple index out of range");
        }
        pnode->Output(0) = elems[norm_idx];
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::RaiseException,
    prim_RaiseException,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::RaiseException)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& message = pnode->Input(0).toStringRef();
        throw std::runtime_error(message);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::Uninitialized,
    prim_Uninitialized,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::Uninitialized)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        pnode->Output(0) = IValue::uninitialized();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::format,
    aten_format,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::format(str self, ...) -> str")) {
        return nullptr;
      }
      TORCH_CHECK(!n->inputs().empty());
      return [](ProcessedNode* pnode) {
        const auto num_inputs = pnode->num_inputs();
        auto stack = boxInputs(*pnode);
        format(stack, num_inputs);
        TORCH_DCHECK_EQ(stack.size(), 1);
        pnode->Output(0) = std::move(stack[0]);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::device,
    prim_device,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "prim::device(Tensor a) -> Device")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = input.device();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::dtype,
    prim_dtype,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::dtype)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = static_cast<int64_t>(input.scalar_type());
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::dim,
    aten_dim,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::dim(Tensor self) -> int")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = input.dim();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::__not__,
    aten_not,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::__not__(bool self) -> bool")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        auto input = pnode->Input(0).toBool();
        pnode->Output(0) = !input;
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::Bool,
    aten_Bool,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema("aten::Bool.Tensor(Tensor a) -> bool"))) {
        return [](ProcessedNode* pnode) {
          const auto& input = pnode->Input(0).toTensor();
          pnode->Output(0) = at::native::is_nonzero(input);
        };
      }
      if (n->matches(torch::schema("aten::Bool.int(int a) -> bool"))) {
        return [](ProcessedNode* pnode) {
          const auto input = pnode->Input(0).toInt();
          pnode->Output(0) = static_cast<bool>(input);
        };
      }
      if (n->matches(torch::schema("aten::Bool.float(float a) -> bool"))) {
        return [](ProcessedNode* pnode) {
          const auto input = pnode->Input(0).toDouble();
          pnode->Output(0) = static_cast<bool>(input);
        };
      }
      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::is_cuda,
    prim_is_cuda,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "prim::is_cuda(Tensor a) -> bool")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = input.is_cuda();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::tolist,
    prim_tolist,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::tolist)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        const auto dim = pnode->Input(1).toInt();
        const auto elem_type = pnode->Input(2).toInt();
        std::vector<IValue> stack{input, dim, elem_type};
        toList(stack);
        TORCH_DCHECK_EQ(stack.size(), 1);
        pnode->Output(0) = std::move(stack[0]);
      };
    });

// See [Borrowed IValue Outputs]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::IfThenElse,
    prim_IfThenElse,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::IfThenElse)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto condition = pnode->Input(0).toBool();
        pnode->Output(0) = condition ? createBorrowedIValue(pnode->Input(1))
                                     : createBorrowedIValue(pnode->Input(2));
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::len,
    aten_len,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema("aten::len.t(t[] a) -> int")) ||
          n->matches(torch::schema("aten::len.any(Any[] a) -> int"))) {
        return [](ProcessedNode* pnode) {
          const auto list = pnode->Input(0).toListRef();
          const int64_t size = list.size();
          pnode->Output(0) = size;
        };
      }
      if (n->matches(torch::schema("aten::len.Tensor(Tensor t) -> int"))) {
        return [](ProcessedNode* pnode) {
          const auto& t = pnode->Input(0).toTensor();
          TORCH_CHECK(t.dim() > 0);
          pnode->Output(0) = t.sizes()[0];
        };
      }
      if (n->matches(torch::schema("aten::len.str(str s) -> int"))) {
        return [](ProcessedNode* pnode) {
          const auto& string = pnode->Input(0).toStringRef();
          pnode->Output(0) = static_cast<int64_t>(string.size());
        };
      }
      if (n->matches(
              torch::schema("aten::len.Dict_str(Dict(str, t) self) -> int")) ||
          n->matches(
              torch::schema("aten::len.Dict_int(Dict(int, t) self) -> int")) ||
          n->matches(torch::schema(
              "aten::len.Dict_bool(Dict(bool, t) self) -> int")) ||
          n->matches(torch::schema(
              "aten::len.Dict_float(Dict(float, t) self) -> int")) ||
          n->matches(torch::schema(
              "aten::len.Dict_complex(Dict(complex, t) self) -> int")) ||
          n->matches(torch::schema(
              "aten::len.Dict_Tensor(Dict(Tensor, t) self) -> int"))) {
        return [](ProcessedNode* pnode) {
          const auto& dict = pnode->Input(0).toGenericDict();
          pnode->Output(0) = static_cast<int64_t>(dict.size());
        };
      }
      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::IntImplicit,
    aten_IntImplicit,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema("aten::IntImplicit(Tensor a) -> int"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& tensor = pnode->Input(0).toTensor();
        // JIT does a check for requires_grad, but we skip it here since SR is
        // inference only
        if (!tensor.sizes().empty()) {
          throw std::runtime_error(
              "Cannot convert a tensor of dimension > 0 to scalar");
        }
        if (!isIntegralType(tensor.scalar_type(), /*includeBool=*/false)) {
          std::stringstream ss;
          ss << "Cannot input a tensor of type " << tensor.scalar_type()
             << " as an integral argument";
          throw std::runtime_error(ss.str());
        }
        pnode->Output(0) = at::native::item(tensor).toInt();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::select,
    aten_select,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::select(Tensor(a) self, int dim, int index) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& self = pnode->Input(0).toTensor();
        const auto dim = pnode->Input(1).toInt();
        const auto index = pnode->Input(2).toInt();
        pnode->Output(0) = at::native::select(self, dim, index);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::reshape_as,
    aten_reshape_as,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::reshape_as(Tensor(a) self, Tensor other) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& self = pnode->Input(0).toTensor();
        const auto& other = pnode->Input(1).toTensor();
        pnode->Output(0) = at::native::reshape(self, other.sizes());
      };
    });

} // namespace torch::jit