#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/ops.h>

#include <ATen/CPUFunctions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/NonSymbolicBC.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/irange.h>
#include <c10/util/ssize.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>

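// createBorrowedIValue wraps an existing IValue without bumping its reference
// count. Ops tagged "[Borrowed IValue Outputs]" below use it to hand out cheap
// views of their inputs; the MemoryPlanner is aware of these borrows and
// cleans them up with the corresponding destroyBorrow (see the note in
// static_runtime::select_tensor further down).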
namespace {
constexpr auto createBorrowedIValue =
    c10::MaybeOwnedTraits<c10::IValue>::createBorrow;
} // namespace
namespace torch::jit {

namespace {

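// Copies a ProcessedNode's inputs onto a fresh stack so that the generic JIT
// vararg helpers used below (tupleConstruct, listConstruct, format, ...) can
// be reused unchanged.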
std::vector<IValue> boxInputs(const ProcessedNode& pnode) {
  std::vector<IValue> result;
  for (const auto i : c10::irange(pnode.num_inputs())) {
    result.push_back(pnode.Input(i));
  }
  return result;
}

} // namespace

C10_DEFINE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor);

bool nativeOpIsRegistered(const c10::Symbol& op_name) {
  const std::string name(op_name.toQualString());
  return SRNativeOperatorRegistry()->Has(name);
}

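// Looks up the functor registered above for n's kind and generates the
// node-specific SROperator; returns nullptr when no native implementation is
// registered for that op.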
SROperator getNativeOperation(Node* n) {
  auto op_name = n->kind().toQualString();
  if (SRNativeOperatorRegistry()->Has(op_name)) {
    return SRNativeOperatorRegistry()->Create(op_name)->Generate(n);
  }
  return nullptr;
}

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::TupleConstruct,
    prim_TupleConstruct,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::TupleConstruct)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        // prepare inputs
        auto stack = boxInputs(*p_node);
        // run op
        auto* node = p_node->node();
        const auto& type = node->output()->type()->expect<TupleType>();
        if (type->name().has_value()) {
          namedTupleConstruct(stack, type, node->inputs().size());
        } else {
          tupleConstruct(stack, node->inputs().size());
        }
        // put output back
        p_node->Output(0) = std::move(stack[0]);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::TupleUnpack,
    prim_TupleUnpack,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::TupleUnpack)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& elems = p_node->Input(0).toTupleRef().elements();
        const size_t num_outputs = p_node->outputs().size();
        TORCH_CHECK(
            num_outputs == elems.size(),
            "Number of outputs must match number of tuple elements.")
        for (size_t i = 0; i < num_outputs; ++i) {
          p_node->Output(i) = elems[i];
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::DictConstruct,
    prim_DictConstruct,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::DictConstruct)) {
        return nullptr;
      }
      auto dict_type = n->output()->type()->expect<DictType>();
      const auto num_inputs = n->inputs().size();
      TORCH_DCHECK_EQ(num_inputs % 2, 0);
      return [dict_type = std::move(dict_type),
              num_inputs,
              dict_size = num_inputs / 2](ProcessedNode* p_node) {
        auto result = c10::impl::GenericDict(
            dict_type->containedType(0), dict_type->containedType(1));
        result.reserve(dict_size);
        for (size_t i = 0; i < num_inputs; i += 2) {
          const auto& key = p_node->Input(i);
          const auto& value = p_node->Input(i + 1);
          result.insert_or_assign(key, value);
        }
        p_node->Output(0) = result;
      };
    });

// See [Borrowed IValue Outputs]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    static_runtime::dict_unpack,
    static_runtime_dict_unpack,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "static_runtime::dict_unpack(...) -> ...")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        DCHECK(
            static_cast<size_t>(p_node->num_inputs() - 1) ==
            p_node->outputs().size());
        auto dict = p_node->Input(0).toGenericDict();
        const auto num_inputs = p_node->num_inputs();
        for (size_t i = 1; i < num_inputs; ++i) {
          const auto& key = p_node->Input(i);
          auto value = dict.find(key);
          TORCH_CHECK(value != dict.end(), "Key not in dict: ", key);
          p_node->Output(i - 1) = createBorrowedIValue(value->value());
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::__getitem__, aten_getitem, [](Node* n) -> SROperator {
  if (!sr_schema_check(
          n,
          // TODO: "aten::__getitem__.str(str s, int index) -> str",
          "aten::__getitem__.t(t[](a) list, int idx) -> t(*)",
          "aten::__getitem__.Dict_str(Dict(str, t) self, str key) -> t(*)",
          "aten::__getitem__.Dict_int(Dict(int, t) self, int key) -> t(*)",
          "aten::__getitem__.Dict_bool(Dict(bool, t) self, bool key) -> t(*)",
          "aten::__getitem__.Dict_float(Dict(float, t) self, float key) -> t(*)",
          "aten::__getitem__.Dict_complex(Dict(complex, t) self, complex key) -> t(*)",
          "aten::__getitem__.Dict_Tensor(Dict(Tensor, t) self, Tensor key) -> t(*)")) {
    return nullptr;
  }

  if (n->inputs().size() != 2) {
    return nullptr;
  }

  if (n->input(0)->type()->castRaw<DictType>()) {
    return [](ProcessedNode* p_node) {
      auto dict = p_node->Input(0).toGenericDict();
      const auto& key = p_node->Input(1);
      auto value = dict.find(key);
      TORCH_CHECK(value != dict.end(), "Key not in dict: ", key);
      p_node->Output(0) = value->value();
    };
  } else if (n->input(0)->type()->castRaw<ListType>()) {
    return [](ProcessedNode* p_node) {
      const auto& list = p_node->Input(0).toList();
      auto idx = p_node->Input(1).toInt();
      p_node->Output(0) = getItem(list, idx);
    };
  }

  // TODO(T98581096): make __getitem__ work for other container types
  return nullptr;
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::ListConstruct,
    prim_ListConstruct,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::ListConstruct)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        // prepare inputs
        auto stack = boxInputs(*p_node);
        // run op
        listConstruct(
            stack,
            p_node->node()->output()->type()->expectRef<ListType>(),
            p_node->num_inputs());
        // put output back
        p_node->Output(0) = std::move(stack[0]);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::ListUnpack,
    prim_ListUnpack,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::ListUnpack)) {
        return nullptr;
      }
      const auto num_outputs = n->outputs().size();
      return [num_outputs](ProcessedNode* p_node) {
        const auto list = p_node->Input(0).toListRef();
        TORCH_CHECK(
            list.size() == num_outputs,
            "Expected ",
            num_outputs,
            " elements in list but got ",
            list.size());
        for (const auto i : c10::irange(num_outputs)) {
          p_node->Output(i) = list[i];
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::append,
    aten_append,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(
              n, "aten::append.t(t[](a!) self, t(c -> *) el) -> t[](a!)")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto list = p_node->Input(0).toList();
        list.push_back(p_node->Input(1));
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::list,
    aten_list,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema("aten::list(str t) -> str[]"))) {
        return [](ProcessedNode* p_node) {
          const auto str = p_node->Input(0).toStringRef();
          c10::List<std::string> chars;
          chars.reserve(str.size());
          for (auto c : str) {
            chars.emplace_back(1, c);
          }
          p_node->Output(0) = std::move(chars);
        };
      }

      if (n->matches(torch::schema("aten::list.t(t[] l) -> t[]"))) {
        return [](ProcessedNode* p_node) {
          const auto input = p_node->Input(0).toList();
          p_node->Output(0) = input.copy();
        };
      }

      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::numel,
    aten_numel,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::numel(Tensor self) -> int")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& arg = p_node->Input(0).toTensor();
        p_node->Output(0) = arg.numel();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::cpu,
    aten_cpu,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::cpu(Tensor self) -> Tensor")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& arg = p_node->Input(0).toTensor();
        p_node->Output(0) = arg.cpu();
      };
    });

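// __range_length computes the number of elements in range(lo, hi, step).
// For example, __range_length(0, 10, 3) == 4 (the elements are 0, 3, 6, 9).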
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::__range_length,
    aten_range_length,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(
              n, "aten::__range_length(int lo, int hi, int step) -> int")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto lo = p_node->Input(0).toInt();
        auto hi = p_node->Input(1).toInt();
        auto step = p_node->Input(2).toInt();
        // error handling when step == 0 during runtime
        if (step == 0) {
          throw std::runtime_error("range() arg 3 must not be zero");
        }
        if (step > 0 && lo < hi) {
          p_node->Output(0) = 1 + (hi - 1 - lo) / step;
        } else if (step < 0 && lo > hi) {
          p_node->Output(0) = 1 + (lo - 1 - hi) / (0 - step);
        } else {
          p_node->Output(0) = 0;
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::index_put, aten_index_put, [](Node* n) -> SROperator {
  if (n->matches(torch::schema(
          "aten::index_put(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor")) ||
      n->matches(torch::schema(
          "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"))) {
    return [](ProcessedNode* p_node) {
      const auto& self = p_node->Input(0).toTensor();
      const auto& indices =
          at::native::toListOfOptionalTensors(p_node->Input(1).toListRef());
      const auto& values = p_node->Input(2).toTensor();
      const auto accumulate = p_node->Input(3).toBool();
      p_node->Output(0) =
          at::native::index_put(self, indices, values, accumulate);
    };
  }

  LogAndDumpSchema(n);
  return nullptr;
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::item,
    aten_item,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::item(Tensor self) -> Scalar")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor();
        p_node->Output(0) = at::native::item(self);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::GetAttr,
    prim_GetAttr,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::GetAttr)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto& module = p_node->Input(0).toObjectRef();
        Node* node = p_node->node();
        const auto& type = node->input()->type()->expectRef<ClassType>();
        const auto& field = node->s(attr::name);
        const auto slot = type.getAttributeSlot(field);
        p_node->Output(0) = module.getSlot(slot);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::SetAttr,
    prim_SetAttr,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::SetAttr)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto& module = p_node->Input(0).toObjectRef();
        Node* node = p_node->node();
        const auto& type = node->inputs()[0]->type()->expectRef<ClassType>();
        const auto& field = node->s(attr::name);
        const auto slot = type.getAttributeSlot(field);
        module.setSlot(slot, p_node->Input(1));
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::transpose,
    aten_transpose,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        const auto in1_i = p_node->Input(1).toInt();
        const auto in2_i = p_node->Input(2).toInt();
        p_node->Output(0) = at::native::transpose(in0_t, in1_i, in2_i);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::flatten, aten_flatten, [](Node* n) -> SROperator {
  if (!n->matches(torch::schema(
          "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)"))) {
    LogAndDumpSchema(n);
    return nullptr;
  }
  return [](ProcessedNode* p_node) {
    const auto& in0_t = p_node->Input(0).toTensor();
    const auto in1_i = p_node->Input(1).toInt();
    const auto in2_i = p_node->Input(2).toInt();
    p_node->Output(0) = at::native::flatten(in0_t, in1_i, in2_i);
  };
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::permute,
    aten_permute,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        const auto in1_iv = p_node->Input(1).toDimVector();
        p_node->Output(0) = at::native::permute(in0_t, in1_iv);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::reshape,
    aten_reshape,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        const auto in1_iv = p_node->Input(1).toDimVector();
        p_node->Output(0) = at::native::reshape(in0_t, in1_iv);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::slice, aten_slice, [](Node* n) -> SROperator {
  if (!n->matches(torch::schema(
          "aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor(a)"))) {
    LogAndDumpSchema(n);
    return nullptr;
  }
  return [](ProcessedNode* p_node) {
    const auto& in0_t = p_node->Input(0).toTensor();
    const auto in1_i = p_node->Input(1).toInt();
    const auto in2_i = p_node->Input(2).toOptional<int64_t>();
    const auto in3_i = p_node->Input(3).toOptional<int64_t>();
    const auto in4_i = p_node->Input(4).toInt();
    p_node->Output(0) = at::native::slice(in0_t, in1_i, in2_i, in3_i, in4_i);
  };
});

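// aten::narrow returns a `length`-sized slice starting at `start` along
// `dim`; e.g. narrow(self, /*dim=*/0, /*start=*/2, /*length=*/3) is
// equivalent to slice(self, 0, 2, 5, 1), i.e. rows 2, 3, and 4.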
REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::narrow, aten_narrow, [](Node* n) -> SROperator {
  if (!n->matches(torch::schema(
          "aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)")) &&
      !n->matches(torch::schema(
          "aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> Tensor(a)"))) {
    LogAndDumpSchema(n);
    return nullptr;
  }
  return [](ProcessedNode* p_node) {
    const auto& self = p_node->Input(0).toTensor(); // self
    const auto dim = p_node->Input(1).toInt(); // dim
    int64_t start = 0;
    if (p_node->Input(2).isScalar()) {
      start = p_node->Input(2).toInt();
    } else {
      auto& t = p_node->Input(2).toTensor();
      start = t.item<int64_t>();
    }
    const auto length = p_node->Input(3).toInt(); // length
    TORCH_CHECK(
        self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
    auto cur_size = self.sizes()[dim];
    if (start != cur_size && start < 0) { // start being the end is valid, but
                                          // not a valid dim specification.
      start = at::maybe_wrap_dim(start, cur_size);
    }
    TORCH_CHECK(
        length >= 0 && start <= cur_size - length,
        "start (",
        start,
        ") + length (",
        length,
        ") exceeds dimension size (",
        cur_size,
        ").");
    p_node->Output(0) = at::native::slice(self, dim, start, start + length, 1);
  };
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::to, aten_to, [](Node* n) -> SROperator {
  if (n->matches(torch::schema(
          "aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
    return [](ProcessedNode* p_node) {
      const auto& in0_t = p_node->Input(0).toTensor();
      const auto& in1_t = p_node->Input(1).toTensor();
      const auto in2_i = p_node->Input(2).toBool();
      const auto in3_i = p_node->Input(3).toBool();
      const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
      p_node->Output(0) = at::native::to(in0_t, in1_t, in2_i, in3_i, in4_o);
    };
  }
  if (n->matches(torch::schema(
          "aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
    return [](ProcessedNode* p_node) {
      const auto& in0_t = p_node->Input(0).toTensor();
      const auto in1_i = p_node->Input(1).toScalarType();
      const auto in2_i = p_node->Input(2).toBool();
      const auto in3_i = p_node->Input(3).toBool();
      const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
      p_node->Output(0) = at::native::to(in0_t, in1_i, in2_i, in3_i, in4_o);
    };
  }
  if (n->matches(torch::schema(
          "aten::to.prim_dtype(Tensor(a) self, int? dtype, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"))) {
    return [](ProcessedNode* p_node) {
      const auto& in0_t = p_node->Input(0).toTensor();
      const auto in1_i = p_node->Input(1).toOptional<at::ScalarType>();
      const auto in2_i = p_node->Input(2).toBool();
      const auto in3_i = p_node->Input(3).toBool();
      // To mimic the behavior of the JIT interpreter, if neither dtype
      // nor copy is set, we return self. Otherwise, we assume that
      // dtype is set.
      if (!in1_i && !in3_i) {
        p_node->Output(0) = in0_t;
      } else {
        TORCH_CHECK(
            in1_i,
            "dtype cannot be None when copy is True for aten::to.prim_dtype");
        p_node->Output(0) = at::native::to(in0_t, *in1_i, in2_i, in3_i);
      }
    };
  }
  LogAndDumpSchema(n);
  return nullptr;
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::detach,
    aten_detach,
    [](Node* n) -> SROperator {
      if (!n->matches(
              torch::schema("aten::detach(Tensor(a) self) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        p_node->Output(0) = at::native::alias(in0_t);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::expand_as,
    aten_expand_as,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor();
        const auto& other = p_node->Input(1).toTensor();
        p_node->Output(0) = self.expand(other.sizes());
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::isinstance,
    prim_isinstance,
    [](Node* n) -> SROperator {
      if (!n->matches(
              torch::schema("prim::isinstance(Any to_check) -> bool"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto input_type = p_node->Input(0).type();

        auto* node = p_node->node();
        const std::vector<TypePtr>& candidates = node->tys(attr::types);
        for (const auto& candidate_type : candidates) {
          if (input_type->isSubtypeOf(*candidate_type)) {
            p_node->Output(0) = true;
            return;
          }
        }

        p_node->Output(0) = false;
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::TypeCheck,
    prim_TypeCheck,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::TypeCheck)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto* node = p_node->node();
        const size_t num_inputs = node->inputs().size();
        TORCH_INTERNAL_ASSERT(
            num_inputs && num_inputs + 1 == node->outputs().size());

        const auto& expected_types = node->tys(attr::types);

        for (size_t i = 0; i < num_inputs; i++) {
          p_node->Output(i) = p_node->Input(i);
        }

        for (size_t i = 0; i < num_inputs; i++) {
          auto& input_tensor = p_node->Input(i).toTensor();
          auto* expected_type = expected_types[i]->castRaw<TensorType>();
          if (input_tensor.defined() &&
              !expected_type->matchTensor(input_tensor)) {
            p_node->Output(num_inputs) = false;
            return;
          }
        }

        p_node->Output(num_inputs) = true;
      };
    });

// See [Borrowed IValue Outputs]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    static_runtime::VarTupleUnpack,
    static_runtime_VarTupleUnpack,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "static_runtime::VarTupleUnpack(...) -> ...")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        size_t output_idx = 0;
        for (const auto idx : c10::irange(pnode->num_inputs())) {
          const auto& tuple = pnode->Input(idx);
          for (auto& elem : tuple.toTupleRef().elements()) {
            pnode->Output(output_idx) = createBorrowedIValue(elem);
            ++output_idx;
          }
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::view,
    aten_view,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::view(Tensor(a) self, int[] size) -> (Tensor(a))"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& input = p_node->Input(0).toTensor();
        const auto size = p_node->Input(1).toIntList();
        p_node->Output(0) = at::native::view(input, size.vec());
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::size,
    aten_size,
    [](Node* n) -> SROperator {
      if (n->matches(
              torch::schema("aten::size(Tensor self, int dim) -> int"))) {
        return [](ProcessedNode* p_node) {
          const auto& input = p_node->Input(0).toTensor();
          auto dim = p_node->Input(1).toInt();
          const auto ndim = input.dim();

          if (dim < 0 || dim >= ndim) {
            dim = c10::maybe_wrap_dim(dim, ndim);
          }
          p_node->Output(0) = input.sizes()[dim];
        };
      }
      if (n->matches(torch::schema("aten::size(Tensor self) -> int[]"))) {
        return [](ProcessedNode* p_node) {
          const auto& input = p_node->Input(0).toTensor();
          p_node->Output(0) = input.sizes();
        };
      }
      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::squeeze,
    aten_squeeze,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }

      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor();
        const auto dim = p_node->Input(1).toInt();
        p_node->Output(0) = at::native::squeeze(self, dim);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::split, aten_split, [](Node* n) -> SROperator {
  if (n->matches(torch::schema(
          "aten::split(Tensor(a -> *) self, int split_size, int dim=0) -> Tensor(a)[]"))) {
    return [](ProcessedNode* p_node) {
      const auto& self = p_node->Input(0).toTensor();
      const auto split_size = p_node->Input(1).toInt();
      const auto dim = p_node->Input(2).toInt();
      p_node->Output(0) = at::native::split(self, split_size, dim);
    };
  }

  if (n->matches(torch::schema(
          "aten::split(Tensor(a -> *) self, int[] split_sizes, int dim=0) -> (Tensor[])"))) {
    return [](ProcessedNode* p_node) {
      const auto& self = p_node->Input(0).toTensor();
      const auto& split_sizes = p_node->Input(1).toIntList();
      const auto dim = p_node->Input(2).toInt();
      p_node->Output(0) =
          at::native::split_with_sizes(self, split_sizes.vec(), dim);
    };
  }

  LogAndDumpSchema(n);
  return nullptr;
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::split_with_sizes,
    aten_split_with_sizes,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::split_with_sizes(Tensor(a -> *) self, int[] split_sizes, int dim=0) -> Tensor(a)[]")) &&
          !n->matches(torch::schema(
              "aten::split_with_sizes(Tensor(a -> *) self, int[] split_sizes, int dim=0) -> (Tensor[])"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor();
        const auto& split_sizes = p_node->Input(1).toIntList();
        const auto dim = p_node->Input(2).toInt();
        p_node->Output(0) =
            at::native::split_with_sizes(self, split_sizes.vec(), dim);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    static_runtime::select_tensor,
    aten_select_tensor,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(
              n,
              "static_runtime::select_tensor(Tensor(a) a, Tensor(b) b, bool use_b) -> Tensor(a|b)")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto did_copy = p_node->Input(2).toBool();
        DCHECK(p_node->Input(0).isTensor());
        DCHECK(!did_copy || p_node->Input(1).isTensor());
        const IValue& assignFrom =
            did_copy ? p_node->Input(1) : p_node->Input(0);
        // Create an IValue that borrows the input Tensor in order to
        // save a refcount increment here and decrement in
        // MemoryPlanner::deallocate. MemoryPlanner knows about this
        // and will safely clean it up by using the corresponding
        // destroyBorrow method.
        TORCH_DCHECK_NE(&assignFrom, &p_node->Output(0));
        // MemoryPlanner should have cleaned this up!
        DCHECK(p_node->Output(0).isNone());
        p_node->Output(0) =
            IValue(c10::MaybeOwnedTraits<at::TensorBase>::createBorrow(
                assignFrom.toTensor()));
      };
    });

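// aten::mul.left_t repeats a list n times, mirroring Python's `list * int`;
// e.g. [1, 2] * 3 == [1, 2, 1, 2, 1, 2].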
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::mul,
    aten_mul,
    [](Node* n) -> SROperator {
      if (!n->matches(
              torch::schema("aten::mul.left_t(t[] l, int n) -> (t[])"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& list = pnode->Input(0).toList();
        const auto n = pnode->Input(1).toInt();

        auto list_type = list.elementType();
        auto ret = c10::impl::GenericList(list_type);
        ret.reserve(list.size() * n);
        for (const auto i : c10::irange(n)) {
          (void)i;
          for (const auto& ival : list) {
            ret.push_back(ival);
          }
        }
        pnode->Output(0) = ret;
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::sub,
    aten_sub,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema("aten::sub.int(int a, int b) -> (int)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto a = pnode->Input(0).toInt();
        const auto b = pnode->Input(1).toInt();
        pnode->Output(0) = a - b;
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::add,
    aten_add,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema("aten::add.t(t[] a, t[] b) -> (t[])"))) {
        return [](ProcessedNode* pnode) {
          const auto& a = pnode->Input(0).toList();
          const auto& b = pnode->Input(1).toList();
          auto ret = a.copy();
          ret.append(b);
          pnode->Output(0) = ret;
        };
      }

      if (n->matches(torch::schema("aten::add.int(int a, int b) -> (int)"))) {
        return [](ProcessedNode* pnode) {
          const auto a = pnode->Input(0).toInt();
          const auto b = pnode->Input(1).toInt();
          pnode->Output(0) = a + b;
        };
      }

      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::tensor_split, aten_tensor_split, [](Node* n) -> SROperator {
  if (n->matches(torch::schema(
          "aten::tensor_split.indices(Tensor(a -> *) self, int[] indices, int dim=0) -> Tensor(a)[]"))) {
    return [](ProcessedNode* pnode) {
      const auto& a = pnode->Input(0).toTensor();
      const auto& b = pnode->Input(1).toIntVector();
      const auto c = pnode->Input(2).toInt();
      pnode->Output(0) = at::native::tensor_split(a, b, c);
    };
  }

  if (n->matches(torch::schema(
          "aten::tensor_split.sections(Tensor(a -> *) self, int sections, int dim=0) -> Tensor(a)[]"))) {
    return [](ProcessedNode* pnode) {
      const auto& a = pnode->Input(0).toTensor();
      const auto b = pnode->Input(1).toSymInt();
      const auto c = pnode->Input(2).toInt();
      pnode->Output(0) = at::native::tensor_split_sections_symint(a, b, c);
    };
  }

  if (n->matches(torch::schema(
          "aten::tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[]"))) {
    return [](ProcessedNode* pnode) {
      const auto& a = pnode->Input(0).toTensor();
      const auto& b = pnode->Input(1).toTensor();
      const auto c = pnode->Input(2).toInt();
      pnode->Output(0) = at::native::tensor_split(a, b, c);
    };
  }
  LogAndDumpSchema(n);
  return nullptr;
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::Int,
    aten_Int,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema("aten::Int(Tensor a) -> int"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = at::native::item(input).toInt();
      };
    });

// See [Create owned refs for special values]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    static_runtime::create_owned_ref,
    static_runtime_create_owned_ref,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "static_runtime::create_owned_ref(...) -> ...")) {
        return nullptr;
      }
      return
          [](ProcessedNode* p_node) { p_node->Output(0) = p_node->Input(0); };
    });

namespace {
bool outputsEmpty(const Block* block) {
  return block->outputs().size() == 1 && block->outputs().at(0)->mustBeNone();
}

bool blockEmpty(const Block* block) {
  return block->nodes().begin() == block->nodes().end();
}

enum class BlockRunPlan : int8_t {
  kRunOnlyTrueBlock,
  kRunOnlyFalseBlock,
  kRunBothBlocks,
  kRunNeitherBlock,
};
} // namespace

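// prim::If is specialized when the operator is generated: if both branches
// return None, a branch with an empty body never needs to run, so its block
// runner can be skipped (or both skipped when both bodies are empty).
// Otherwise both block runners are kept and the condition selects which one
// to execute.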
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::If,
    prim_If,
    [](Node* node) -> SROperator {
      if (!sr_schema_check_kind(node, prim::If)) {
        return nullptr;
      }
      TORCH_DCHECK_EQ(node->blocks().size(), 2);
      const Block* true_block = node->blocks().at(0);
      const Block* false_block = node->blocks().at(1);

      const bool true_block_returns_empty = outputsEmpty(true_block);
      const bool false_block_returns_empty = outputsEmpty(false_block);

      BlockRunPlan block_run_plan = BlockRunPlan::kRunNeitherBlock;

      if (true_block_returns_empty && false_block_returns_empty) {
        const bool false_block_is_empty = blockEmpty(false_block);
        const bool true_block_is_empty = blockEmpty(true_block);

        if (false_block_is_empty && !true_block_is_empty) {
          block_run_plan = BlockRunPlan::kRunOnlyTrueBlock;
        } else if (!false_block_is_empty && true_block_is_empty) {
          block_run_plan = BlockRunPlan::kRunOnlyFalseBlock;
        } else if (false_block_is_empty && true_block_is_empty) {
          block_run_plan = BlockRunPlan::kRunNeitherBlock;
        } else {
          block_run_plan = BlockRunPlan::kRunBothBlocks;
        }
      } else {
        block_run_plan = BlockRunPlan::kRunBothBlocks;
      }

      switch (block_run_plan) {
        case BlockRunPlan::kRunBothBlocks:
          return [](ProcessedNode* p_node) {
            auto condition = p_node->Input(0).toBool();
            auto* metadata = p_node->metadata();
            DCHECK(metadata);
            auto& block_runners = metadata->block_runners();
            TORCH_DCHECK_EQ(block_runners.size(), 2);
            auto& runner = block_runners[!condition];

            auto output = runner({});
            // If we are returning a tuple, we are either returning
            // multiple unpacked values or all of the values wrapped
            // in a single tuple. The second condition handles the
            // latter case.
            if (!output.isTuple() || p_node->num_outputs() == 1) {
              p_node->Output(0) = std::move(output);
              return;
            }
            auto& elems = output.toTupleRef().elements();
            TORCH_DCHECK_EQ(elems.size(), p_node->num_outputs());
            for (const auto i : c10::irange(elems.size())) {
              p_node->Output(i) = elems[i];
            }
          };
        case BlockRunPlan::kRunOnlyTrueBlock:
          return [](ProcessedNode* p_node) {
            auto condition = p_node->Input(0).toBool();
            auto* metadata = p_node->metadata();
            DCHECK(metadata);
            auto& block_runners = metadata->block_runners();
            TORCH_DCHECK_EQ(block_runners.size(), 2);
            if (condition) {
              auto output = block_runners.front()({});
              DCHECK(output.isNone());
            }
          };
        case BlockRunPlan::kRunOnlyFalseBlock:
          return [](ProcessedNode* p_node) {
            auto condition = p_node->Input(0).toBool();
            auto* metadata = p_node->metadata();
            DCHECK(metadata);
            auto& block_runners = metadata->block_runners();
            TORCH_DCHECK_EQ(block_runners.size(), 2);
            if (!condition) {
              auto output = block_runners.back()({});
              DCHECK(output.isNone());
            }
          };
        case BlockRunPlan::kRunNeitherBlock:
          return [](ProcessedNode*) {};
      }
      return [](ProcessedNode*) {};
    });

namespace {

std::vector<IValue> collectLoopSubBlockInputs(const ProcessedNode& p_node) {
  const auto num_inputs = p_node.num_inputs();
  TORCH_DCHECK_GE(num_inputs, 2);
  // The first two inputs to the loop node are the max trip count
  // and initial condition. We don't collect them here, since those
  // are not inputs for the sub-block.
  const auto num_args = num_inputs - 2;

  std::vector<IValue> result;
  result.reserve(num_args + 1);
  // First argument to the loop sub-block is always the loop counter, initially
  // zero.
  result.emplace_back(0);

  for (const auto i : c10::irange(num_args)) {
    result.push_back(p_node.Input(2 + i));
  }

  return result;
}

} // namespace

namespace {
/*
  ForkedSubgraphSRLauncher is responsible for executing a forked subgraph
  on a new StaticRuntime instance. Once execution completes, the future is
  marked complete to signal the corresponding aten::wait() to proceed.
*/
class TORCH_API ForkedSubgraphSRLauncher {
 public:
  ForkedSubgraphSRLauncher(
      std::shared_ptr<StaticModule> smodule,
      std::vector<IValue> args,
      c10::intrusive_ptr<Future> future,
      TaskLauncher launcher)
      : smodule_(std::move(smodule)),
        args_(std::move(args)),
        future_(std::move(future)),
        launcher_(std::move(launcher)) {}

  void operator()() {
    try {
      StaticRuntime runtime(*smodule_);
      auto future_subgraph = runtime.runAsync(args_, {}, launcher_);
      future_subgraph->waitAndThrow();
      future_->markCompleted(future_subgraph->value());
    } catch (const std::exception& e) {
      future_->setErrorIfNeeded(
          std::make_exception_ptr(c10::ivalue::Future::FutureError(e.what())));
    }
  }

 private:
  std::shared_ptr<StaticModule> smodule_;
  std::vector<IValue> args_;
  c10::intrusive_ptr<Future> future_;
  torch::jit::TaskLauncher launcher_;
};

/*
  Helper function to create a Future whose element type matches the graph's
  return type (a single output's type, or a tuple of the output types). It is
  used by the prim::fork and aten::wait operations for async execution of
  subgraphs.
*/
c10::intrusive_ptr<Future> createFutureTypeFromGraphOutput(
    const std::shared_ptr<torch::jit::Graph>& graph) {
  TypePtr return_type_;
  if (graph->outputs().size() == 1) {
    return_type_ = graph->outputs().at(0)->type();
  } else {
    return_type_ = TupleType::create(
        fmap(graph->outputs(), [](const Value* v) { return v->type(); }));
  }
  c10::intrusive_ptr<Future> future = c10::make_intrusive<Future>(return_type_);
  return future;
}
} // namespace

/*
  prim::fork forks the execution of a subgraph. It returns a future on which
  the corresponding aten::wait op waits until the future is marked complete.
  The current implementation creates an instance of StaticModule and uses it
  to create StaticRuntime instances on the fly at runtime to handle the
  execution of the forked subgraph. Async execution is handled by the
  aten::ParallelThreadPoolNative threadpool.
*/
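// As an illustration, TorchScript such as
//   fut = torch.jit.fork(fn, x)
//   y = torch.jit.wait(fut)
// lowers to the prim::fork / aten::wait nodes handled by the two functors
// below.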
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::fork,
    prim_Fork,
    [](Node* node) -> SROperator {
      if (!sr_schema_check_kind(node, prim::fork)) {
        return nullptr;
      }
      auto forkedGraph = node->g(attr::Subgraph);
      Inline(*forkedGraph);
      auto sr_metadata = node->ival(getStaticRuntimeMetadataSymbol())
                             .toCustomClass<StaticRuntimeMetadata>();
      auto smodule =
          std::make_shared<StaticModule>(forkedGraph, sr_metadata->get_opts());

      return [forkedGraph = std::move(forkedGraph),
              smodule = std::move(smodule)](ProcessedNode* p_node) {
        std::vector<IValue> args;
        args.reserve(p_node->num_inputs());
        for (const auto i : c10::irange(p_node->num_inputs())) {
          args.push_back(p_node->Input(i));
        }

        c10::intrusive_ptr<Future> future =
            createFutureTypeFromGraphOutput(forkedGraph);
        p_node->Output(0) = future;

        auto* metadata = p_node->metadata();
        DCHECK(metadata);
        auto* launcher = metadata->launcher();
        DCHECK(launcher);
        ForkedSubgraphSRLauncher runtime_launcher(
            smodule, args, future, *launcher);
        (*launcher)(std::move(runtime_launcher));
      };
    });
/*
  aten::wait blocks on the future produced by the corresponding prim::fork.
  Once that execution completes and the future is marked complete, execution
  continues past the wait.
*/
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::wait,
    aten_Wait,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::wait(Future(t) self) -> t")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        TORCH_INTERNAL_ASSERT(p_node->Input(0).isFuture());
        auto future = p_node->Input(0).toFuture();

        // blocking call: waiting for the future to be completed
        future->waitAndThrow();

        TORCH_INTERNAL_ASSERT(future->completed());
        TORCH_INTERNAL_ASSERT(!future->hasError());
        TORCH_INTERNAL_ASSERT(future->hasValue());

        if (!future->value().isTuple()) {
          p_node->Output(0) = future->value();
          return;
        }
        auto& elems = future->value().toTupleRef().elements();
        TORCH_DCHECK_EQ(elems.size(), p_node->num_outputs());
        for (const auto i : c10::irange(elems.size())) {
          p_node->Output(i) = elems[i];
        }
      };
    });

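// prim::Loop runs its single sub-block at most max_trip_count times while the
// carried condition stays true. Each iteration the sub-block receives the loop
// counter followed by the carried values (see collectLoopSubBlockInputs
// above); its first output becomes the next condition and the remaining
// outputs become the carried values for the next iteration.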
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::Loop,
    prim_Loop,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::Loop)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto max_trip_count = p_node->Input(0).toInt();
        auto condition = p_node->Input(1).toBool();

        auto* metadata = p_node->metadata();
        DCHECK(metadata);
        auto& block_runners = metadata->block_runners();
        TORCH_DCHECK_EQ(block_runners.size(), 1);
        auto& runner = block_runners[0];

        auto args = collectLoopSubBlockInputs(*p_node);
        int64_t loop_count = 0;

        while (condition && loop_count < max_trip_count) {
          auto output = runner(args);

          if (output.isTuple()) {
            auto& elems = output.toTupleRef().elements();
            DCHECK(elems.size() == args.size());
            for (const auto i : c10::irange(1, args.size())) {
              args[i] = elems[i];
            }
            condition = elems[0].toBool();
          } else {
            condition = output.toBool();
          }
          args[0] = ++loop_count;
        }

        const auto num_outputs = p_node->num_outputs();
        TORCH_DCHECK_EQ(args.size(), num_outputs + 1);
        for (const auto i : c10::irange(num_outputs)) {
          p_node->Output(i) = std::move(args[i + 1]);
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::CreateObject,
    prim_CreateObject,
    [](Node* node) -> SROperator {
      if (!sr_schema_check_kind(node, prim::CreateObject)) {
        return nullptr;
      }
      auto class_type = node->output()->type()->expect<ClassType>();
      return [class_type = std::move(class_type)](ProcessedNode* pnode) {
        pnode->Output(0) = c10::ivalue::Object::create(
            c10::StrongTypePtr(class_type->compilation_unit(), class_type),
            class_type->numAttributes());
      };
    });

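// prim::TupleIndex supports Python-style negative indices; normalizeIndex is
// expected to map an index like -1 to the last element before the bounds
// check below.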
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::TupleIndex,
    prim_TupleIndex,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::TupleIndex)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& elems = pnode->Input(0).toTupleRef().elements();
        using c10::ssize;
        const auto num_elems = ssize(elems);
        const auto idx = pnode->Input(1).toInt();
        const auto norm_idx = normalizeIndex(idx, num_elems);
        if (norm_idx < 0 || norm_idx >= num_elems) {
          // Use std::runtime_error instead of c10::Error to be consistent with
          // JIT
          throw std::out_of_range("Tuple index out of range");
        }
        pnode->Output(0) = elems[norm_idx];
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::RaiseException,
    prim_RaiseException,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::RaiseException)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& message = pnode->Input(0).toStringRef();
        throw std::runtime_error(message);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::Uninitialized,
    prim_Uninitialized,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::Uninitialized)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        pnode->Output(0) = IValue::uninitialized();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::format,
    aten_format,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::format(str self, ...) -> str")) {
        return nullptr;
      }
      TORCH_CHECK(!n->inputs().empty());
      return [](ProcessedNode* pnode) {
        const auto num_inputs = pnode->num_inputs();
        auto stack = boxInputs(*pnode);
        format(stack, num_inputs);
        TORCH_DCHECK_EQ(stack.size(), 1);
        pnode->Output(0) = std::move(stack[0]);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::device,
    prim_device,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "prim::device(Tensor a) -> Device")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = input.device();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::dtype,
    prim_dtype,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::dtype)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = static_cast<int64_t>(input.scalar_type());
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::dim,
    aten_dim,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::dim(Tensor self) -> int")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = input.dim();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::__not__,
    aten_not,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::__not__(bool self) -> bool")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        auto input = pnode->Input(0).toBool();
        pnode->Output(0) = !input;
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::Bool,
    aten_Bool,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema("aten::Bool.Tensor(Tensor a) -> bool"))) {
        return [](ProcessedNode* pnode) {
          const auto& input = pnode->Input(0).toTensor();
          pnode->Output(0) = at::native::is_nonzero(input);
        };
      }
      if (n->matches(torch::schema("aten::Bool.int(int a) -> bool"))) {
        return [](ProcessedNode* pnode) {
          const auto input = pnode->Input(0).toInt();
          pnode->Output(0) = static_cast<bool>(input);
        };
      }
      if (n->matches(torch::schema("aten::Bool.float(float a) -> bool"))) {
        return [](ProcessedNode* pnode) {
          const auto input = pnode->Input(0).toDouble();
          pnode->Output(0) = static_cast<bool>(input);
        };
      }
      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::is_cuda,
    prim_is_cuda,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "prim::is_cuda(Tensor a) -> bool")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = input.is_cuda();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::tolist,
    prim_tolist,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::tolist)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        const auto dim = pnode->Input(1).toInt();
        const auto elem_type = pnode->Input(2).toInt();
        std::vector<IValue> stack{input, dim, elem_type};
        toList(stack);
        TORCH_DCHECK_EQ(stack.size(), 1);
        pnode->Output(0) = std::move(stack[0]);
      };
    });

// See [Borrowed IValue Outputs]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::IfThenElse,
    prim_IfThenElse,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::IfThenElse)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto condition = pnode->Input(0).toBool();
        pnode->Output(0) = condition ? createBorrowedIValue(pnode->Input(1))
                                     : createBorrowedIValue(pnode->Input(2));
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::len,
    aten_len,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema("aten::len.t(t[] a) -> int")) ||
          n->matches(torch::schema("aten::len.any(Any[] a) -> int"))) {
        return [](ProcessedNode* pnode) {
          const auto list = pnode->Input(0).toListRef();
          const int64_t size = list.size();
          pnode->Output(0) = size;
        };
      }
      if (n->matches(torch::schema("aten::len.Tensor(Tensor t) -> int"))) {
        return [](ProcessedNode* pnode) {
          const auto& t = pnode->Input(0).toTensor();
          TORCH_CHECK(t.dim() > 0);
          pnode->Output(0) = t.sizes()[0];
        };
      }
      if (n->matches(torch::schema("aten::len.str(str s) -> int"))) {
        return [](ProcessedNode* pnode) {
          const auto& string = pnode->Input(0).toStringRef();
          pnode->Output(0) = static_cast<int64_t>(string.size());
        };
      }
      if (n->matches(
              torch::schema("aten::len.Dict_str(Dict(str, t) self) -> int")) ||
          n->matches(
              torch::schema("aten::len.Dict_int(Dict(int, t) self) -> int")) ||
          n->matches(torch::schema(
              "aten::len.Dict_bool(Dict(bool, t) self) -> int")) ||
          n->matches(torch::schema(
              "aten::len.Dict_float(Dict(float, t) self) -> int")) ||
          n->matches(torch::schema(
              "aten::len.Dict_complex(Dict(complex, t) self) -> int")) ||
          n->matches(torch::schema(
              "aten::len.Dict_Tensor(Dict(Tensor, t) self) -> int"))) {
        return [](ProcessedNode* pnode) {
          const auto& dict = pnode->Input(0).toGenericDict();
          pnode->Output(0) = static_cast<int64_t>(dict.size());
        };
      }
      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::IntImplicit,
    aten_IntImplicit,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema("aten::IntImplicit(Tensor a) -> int"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& tensor = pnode->Input(0).toTensor();
        // JIT does a check for requires_grad, but we skip it here since SR is
        // inference only
        if (!tensor.sizes().empty()) {
          throw std::runtime_error(
              "Cannot convert a tensor of dimension > 0 to scalar");
        }
        if (!isIntegralType(tensor.scalar_type(), /*includeBool=*/false)) {
          std::stringstream ss;
          ss << "Cannot input a tensor of type " << tensor.scalar_type()
             << " as an integral argument";
          throw std::runtime_error(ss.str());
        }
        pnode->Output(0) = at::native::item(tensor).toInt();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::select,
    aten_select,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::select(Tensor(a) self, int dim, int index) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& self = pnode->Input(0).toTensor();
        const auto dim = pnode->Input(1).toInt();
        const auto index = pnode->Input(2).toInt();
        pnode->Output(0) = at::native::select(self, dim, index);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::reshape_as,
    aten_reshape_as,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::reshape_as(Tensor(a) self, Tensor other) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& self = pnode->Input(0).toTensor();
        const auto& other = pnode->Input(1).toTensor();
        pnode->Output(0) = at::native::reshape(self, other.sizes());
      };
    });

} // namespace torch::jit