1 #include <torch/csrc/jit/python/python_ir.h>
2
3 #include <ATen/core/jit_type.h>
4 #include <pybind11/pybind11.h>
5 #include <torch/csrc/Device.h>
6 #include <torch/csrc/Dtype.h>
7 #include <torch/csrc/api/include/torch/python.h>
8 #include <torch/csrc/jit/ir/alias_analysis.h>
9 #include <torch/csrc/jit/ir/ir.h>
10 #include <torch/csrc/jit/passes/canonicalize.h>
11 #include <torch/csrc/jit/passes/onnx/helper.h>
12 #include <torch/csrc/jit/passes/shape_analysis.h>
13 #include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
14 #include <torch/csrc/jit/python/pybind.h>
15 #include <torch/csrc/jit/python/python_tracer.h>
16 #include <torch/csrc/jit/runtime/argument_spec.h>
17 #include <torch/csrc/jit/serialization/export.h>
18 #include <torch/csrc/jit/serialization/python_print.h>
19 #include <torch/csrc/python_headers.h>
20 #include <torch/csrc/utils/pybind.h>
21 #include <torch/csrc/utils/python_strings.h>
22 #include <iostream>
23 #include <sstream>
24 #include <utility>
25
26 namespace torch::jit {
27
28 // Controls whether graph source ranges are printed by default
29 bool global_print_source_ranges = true;
30
31 Symbol ConcretePythonOp::Kind = prim::PythonOp;
32
33 using c10::Type;
34
getPythonName(const PyObject * obj_)35 std::string getPythonName(const PyObject* obj_) {
36 pybind11::gil_scoped_acquire gil;
37 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
38 PyObject* obj = const_cast<PyObject*>(obj_);
39 auto v = py::getattr(obj, "__name__", py::str("<python_value>"));
40 // if this was a autograd.Function recover the name of the class
41 return py::str(v);
42 }
43
printPyObject(std::ostream & out,const THPObjectPtr & obj)44 std::ostream& printPyObject(std::ostream& out, const THPObjectPtr& obj) {
45 pybind11::gil_scoped_acquire gil;
46 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
47 auto pyobj = py::handle(const_cast<PyObject*>(obj.get()));
48 if (py::isinstance<py::tuple>(pyobj)) {
49 // This special-case for printing tuples handles a problem where
50 // str((2L, 3L)) outputs "(2L, 3L)" in Python 2 but "(2, 3)"
51 // in Python 3. In order to suppress the L-suffix, we must
52 // manually print the string ourselves, calling str() on the
53 // sub-elements.
54 //
55 // This is a fairly fragile fix (What if you have nested tuples
56 // in tuples? What if you have dictionaries?) but it seems to hit
57 // the cases that are triggered in practice in onnx-pytorch. Revisit
58 // this code if this is not the case.
59 //
60 // By the way, one non-solution for this problem is to monkeypatch
61 // tuple.__str__; this doesn't work because Python doesn't allow
62 // monkeypatching methods of built-in types.
63 auto pytuple = pyobj.cast<py::tuple>();
64 out << "(";
65 size_t i = 0;
66 for (const auto& o : pytuple) {
67 if (i > 0) {
68 out << ", ";
69 }
70 THPObjectPtr str(py::str(o).release().ptr());
71 out << THPUtils_unpackString(str.get());
72 i++;
73 }
74 if (i == 1) {
75 out << ",";
76 }
77 out << ")";
78 return out;
79 } else {
80 return out << THPUtils_unpackString(py::str(pyobj).ptr());
81 }
82 }
83
findNode(c10::ArrayRef<torch::jit::Block * > blocks,Symbol kind,bool recurse=true)84 Node* findNode(
85 c10::ArrayRef<torch::jit::Block*> blocks,
86 Symbol kind,
87 bool recurse = true) {
88 for (Block* block : blocks) {
89 for (Node* n : block->nodes()) {
90 if (n->kind() == kind) {
91 return n;
92 }
93 if (recurse) {
94 auto node = findNode(n->blocks(), kind, recurse);
95 if (node != nullptr) {
96 return node;
97 }
98 }
99 }
100 }
101 return nullptr;
102 }
103
findNode(Block * block,Symbol kind,bool recurse=true)104 Node* findNode(Block* block, Symbol kind, bool recurse = true) {
105 std::vector<Block*> blocks = {block};
106 return findNode(blocks, kind, recurse);
107 }
108
name() const109 std::string ConcretePythonOp::name() const {
110 pybind11::gil_scoped_acquire gil;
111 if (auto autograd = autogradFunction()) {
112 return getPythonName(autograd->get());
113 } else {
114 return getPythonName(pyobj.get());
115 }
116 }
117
cloneFrom(Node * other_)118 void ConcretePythonOp::cloneFrom(Node* other_) {
119 // NOLINTNEXTLINE(bugprone-parent-virtual-call)
120 Node::cloneFrom(other_);
121 auto other = other_->cast<ConcretePythonOp>();
122 this->cconv = other->cconv;
123 Py_INCREF(other->pyobj.get());
124 this->pyobj = THPObjectPtr(other->pyobj.get());
125 for (auto& sa : other->scalar_args) {
126 Py_INCREF(sa.get());
127 this->scalar_args.emplace_back(sa.get());
128 }
129 }
130
131 // recover the autograd.Function instance, if this PythonOp's function
132 // was originally SomeFunction.apply
133 // used in ONNX for discovering symbolics
autogradFunction() const134 std::optional<THPObjectPtr> ConcretePythonOp::autogradFunction() const {
135 pybind11::gil_scoped_acquire gil;
136 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
137 py::handle obj = const_cast<PyObject*>(pyobj.get());
138
139 auto r = py::getattr(obj, "__self__", py::none());
140 if (r.is_none())
141 return std::nullopt;
142
143 auto apply = py::getattr(r, "apply", py::none());
144 if (apply.is_none())
145 return std::nullopt;
146
147 auto c = PyObject_RichCompareBool(apply.ptr(), obj.ptr(), Py_NE);
148 if (PyErr_Occurred())
149 throw py::error_already_set();
150 if (c)
151 return std::nullopt;
152
153 return THPObjectPtr(r.release().ptr());
154 }
155
writeScalars(std::ostream & out) const156 void ConcretePythonOp::writeScalars(std::ostream& out) const {
157 out << "(";
158 int i = 0;
159 for (auto& scalar : scalar_args) {
160 if (i++ > 0)
161 out << ", ";
162 printPyObject(out, scalar);
163 }
164 out << ")";
165 }
166
lint_python() const167 void ConcretePythonOp::lint_python() const {
168 size_t n_scalars = 0, n_tensors = 0;
169 for (auto c : cconv) {
170 if (c == 'c') {
171 n_scalars++;
172 } else if (c == 'd') {
173 n_tensors++;
174 } else {
175 AT_ASSERT(0);
176 }
177 AT_ASSERT(static_cast<bool>(pyobj));
178 }
179 AT_ASSERT(n_scalars == scalar_args.size());
180 AT_ASSERT(n_tensors == inputs().size());
181 }
182
createPythonOp(THPObjectPtr && pyobj,const std::string & cconv,pyobj_list && scalar_args)183 Node* Graph::createPythonOp(
184 THPObjectPtr&& pyobj,
185 const std::string& cconv,
186 pyobj_list&& scalar_args) {
187 ConcretePythonOp* op = new ConcretePythonOp(this);
188 return op->init(std::move(pyobj), cconv, std::move(scalar_args));
189 }
190
initPythonIRBindings(PyObject * module_)191 void initPythonIRBindings(PyObject* module_) {
192 auto m = py::handle(module_).cast<py::module>();
193
194 py::class_<AliasDb, std::shared_ptr<AliasDb>>(m, "AliasDb")
195 .def("dump", &AliasDb::dump)
196 .def("to_graphviz_str", &AliasDb::toGraphviz)
197 .def(
198 "may_contain_alias",
199 [&](AliasDb& db, Value* v1, Value* v2) {
200 return db.mayContainAlias(v1, v2);
201 })
202 .def(
203 "has_writers",
204 [&](AliasDb& db, Value* v1) { return db.hasWriters(v1); })
205 .def("__str__", &AliasDb::toString)
206 .def(
207 "move_after_topologically_valid",
208 [](AliasDb& db, Node* n, Node* movePoint) {
209 return db.moveAfterTopologicallyValid(n, movePoint);
210 })
211 .def(
212 "move_before_topologically_valid",
213 [](AliasDb& db, Node* n, Node* movePoint) {
214 return db.moveBeforeTopologicallyValid(n, movePoint);
215 });
216 #define GS(name) def(#name, &Graph ::name)
217 py::class_<Graph, std::shared_ptr<Graph>>(m, "Graph")
218 .def(py::init<>())
219 .def(
220 "__repr__",
221 [&](Graph& g) { return g.toString(global_print_source_ranges); })
222 .def("str", &Graph::toString, py::arg("print_source_ranges") = true)
223 .def_readonly_static(
224 "global_print_source_ranges", &global_print_source_ranges)
225 .def_static(
226 "set_global_print_source_ranges",
227 [&](const bool enabled) { global_print_source_ranges = enabled; },
228 py::arg("enabled") = true)
229 .def(
230 "alias_db",
231 [](std::shared_ptr<Graph> g,
232 bool isFrozen = false,
233 bool descend_function_calls = false) {
234 return std::make_shared<AliasDb>(
235 std::move(g), isFrozen, descend_function_calls);
236 },
237 py::arg("isFrozen") = false,
238 py::arg("descend_function_calls") = false)
239 .def(
240 "dump_alias_db",
241 [](std::shared_ptr<Graph> g) {
242 AliasDb db(std::move(g));
243 db.dump();
244 })
245 .def(
246 "_export_onnx",
247 [](const std::shared_ptr<Graph>& g,
248 const std::map<std::string, at::Tensor>& initializers,
249 int64_t onnx_opset_version,
250 const std::unordered_map<
251 std::string,
252 std::unordered_map<int64_t, std::string>>& dynamic_axes,
253 bool defer_weight_export,
254 ::torch::onnx::OperatorExportTypes operator_export_type,
255 bool strip_doc_string,
256 bool keep_initializers_as_inputs,
257 const std::map<std::string, int>& custom_opsets,
258 bool add_node_names,
259 const std::string& onnx_file_path,
260 const NodeAttrNameMap& node_attr_to_name) {
261 std::string graph;
262 auto
263 [model_proto,
264 export_map,
265 symbol_map,
266 val_use_external_data_format,
267 onnx_node_names] =
268 export_onnx(
269 g,
270 initializers,
271 onnx_opset_version,
272 dynamic_axes,
273 defer_weight_export,
274 operator_export_type,
275 strip_doc_string,
276 keep_initializers_as_inputs,
277 custom_opsets,
278 add_node_names,
279 false,
280 onnx_file_path,
281 node_attr_to_name);
282 std::unordered_map<std::string, py::bytes>
283 python_serialized_export_map;
284 for (auto& kv : export_map) {
285 auto t = kv.second;
286 size_t copy_bytes = t.element_size() * t.numel();
287 // TODO: this is an unnecessary copy. In theory we can directly
288 // return the map from identifier to Tensor, but we need some API
289 // in Python to get raw `bytes` containing the raw tensor data.
290 python_serialized_export_map[kv.first] =
291 py::bytes(static_cast<const char*>(t.data_ptr()), copy_bytes);
292 }
293 graph = serialize_model_proto_to_string(model_proto);
294 return std::make_tuple(
295 py::bytes(graph),
296 python_serialized_export_map,
297 val_use_external_data_format,
298 onnx_node_names);
299 },
300 py::arg("initializers"),
301 py::arg("onnx_opset_version") = 0,
302 py::arg("dynamic_axes"),
303 py::arg("defer_weight_export") = false,
304 py::arg("operator_export_type") =
305 ::torch::onnx::OperatorExportTypes::ONNX,
306 py::arg("strip_doc_string") = true,
307 py::arg("keep_initializers_as_inputs") = true,
308 py::arg("custom_opsets"),
309 py::arg("add_node_names") = true,
310 py::arg("onnx_file_path") = std::string(),
311 py::arg("node_attr_to_name") = NodeAttrNameMap())
312 .def(
313 "_pretty_print_onnx",
314 [](const std::shared_ptr<Graph>& g,
315 const std::map<std::string, at::Tensor>& initializers,
316 int64_t onnx_opset_version,
317 bool defer_weight_export,
318 ::torch::onnx::OperatorExportTypes operator_export_type,
319 bool google_printer,
320 bool keep_initializers_as_inputs,
321 const std::map<std::string, int>& custom_opsets,
322 bool add_node_names) {
323 return pretty_print_onnx(
324 g,
325 initializers,
326 onnx_opset_version,
327 defer_weight_export,
328 operator_export_type,
329 google_printer,
330 keep_initializers_as_inputs,
331 custom_opsets,
332 add_node_names);
333 },
334 py::arg("initializers"),
335 py::arg("onnx_opset_version") = 0,
336 py::arg("defer_weight_export") = false,
337 py::arg("operator_export_type") =
338 ::torch::onnx::OperatorExportTypes::ONNX,
339 py::arg("google_printer") = false,
340 py::arg("keep_initializers_as_inputs") = true,
341 py::arg("custom_opsets"),
342 py::arg("add_node_names") = true)
343 .def(
344 "inputs",
345 [](Graph& g) {
346 return py::make_iterator(g.inputs().begin(), g.inputs().end());
347 },
348 py::keep_alive<0, 1>())
349 .def(
350 "outputs",
351 [](Graph& g) {
352 return py::make_iterator(g.outputs().begin(), g.outputs().end());
353 },
354 py::keep_alive<0, 1>())
355 // We keep the graph alive while the iterator lives. Destroying
356 // nodes might still be hazardous.
357 .def(
358 "nodes",
359 [](Graph& g) {
360 return py::make_iterator(g.nodes().begin(), g.nodes().end());
361 },
362 py::keep_alive<0, 1>())
363 .def(
364 "findNode",
365 [](Graph& g, const std::string& kind, bool recurse) {
366 return findNode(g.block(), Symbol::fromQualString(kind), recurse);
367 },
368 "Find Node",
369 py::arg("kind"),
370 py::arg("recurse") = true)
371 .def(
372 "findAllNodes",
373 [](Graph& g, const std::string& kind, bool recurse) {
374 return findAllNodes(g, Symbol::fromQualString(kind), recurse);
375 },
376 "Find all nodes",
377 py::arg("kind"),
378 py::arg("recurse") = true)
379 .def(
380 "addInput",
381 [](Graph& g, const std::string& name) { return g.addInput(name); },
382 "Add input to graph with optional name seed",
383 py::arg("name") = "")
384 .def("copy", [](Graph& g) { return g.copy(); })
385 .GS(eraseInput)
386 .GS(eraseOutput)
387 .GS(registerOutput)
388 .def(
389 "permuteInputs",
390 [](Graph& g, const std::vector<size_t>& new_inputs) {
391 g.block()->permuteInputs(new_inputs);
392 })
393 .def(
394 "create",
395 [](Graph& g, const char* str) {
396 return g.create(Symbol::fromQualString(str));
397 })
398 .def(
399 "create",
400 [](Graph& g, const char* str, size_t noutputs) {
401 return g.create(Symbol::fromQualString(str), noutputs);
402 })
403 .def(
404 "create",
405 [](Graph& g, const char* str, const std::vector<Value*>& inputs) {
406 TORCH_CHECK_VALUE(
407 std::all_of(
408 inputs.begin(),
409 inputs.end(),
410 [](Value* v) { return (v != nullptr); }),
411 "cannot pass None in inputs");
412 return g.create(Symbol::fromQualString(str), inputs);
413 })
414 .def(
415 "create",
416 [](Graph& g,
417 const char* str,
418 const std::vector<Value*>& inputs,
419 size_t noutputs) {
420 TORCH_CHECK_VALUE(
421 std::all_of(
422 inputs.begin(),
423 inputs.end(),
424 [](Value* v) { return (v != nullptr); }),
425 "cannot pass None in inputs");
426 return g.create(Symbol::fromQualString(str), inputs, noutputs);
427 })
428 .def("param_node", [](Graph& g) { return g.block()->param_node(); })
429 .def("return_node", [](Graph& g) { return g.block()->return_node(); })
430 .def(
431 "createFusionGroup",
432 [](Graph& g) { return g.createWithSubgraph(prim::FusionGroup); })
433 .def(
434 "createCudaFusionGroup",
435 [](Graph& g) { return g.createWithSubgraph(prim::CudaFusionGroup); })
436 .def(
437 "createClone",
438 [](Graph& g, Node* n, py::object fn) {
439 return g.createClone(
440 n, [&](Value* e) { return fn(e).cast<Value*>(); });
441 })
442 .GS(appendNode)
443 .GS(prependNode)
444 // NB: insert_point_guard defined over direct modification of insert point
445 .def(
446 "insert_point_guard",
447 [](Graph& g, Node* n) {
448 return py::module::import("torch.jit._ir_utils")
449 .attr("insert_point_guard")(g, n);
450 })
451 .def(
452 "insert_point_guard",
453 [](Graph& g, Block* b) {
454 return py::module::import("torch.jit._ir_utils")
455 .attr("insert_point_guard")(g, b);
456 })
457 .GS(insertPoint)
458 .def("setInsertPoint", [](Graph& g, Node* n) { g.setInsertPoint(n); })
459 .def("setInsertPoint", [](Graph& g, Block* n) { g.setInsertPoint(n); })
460 .def(
461 "insertGraph",
462 [](Graph& g, Graph& callee, const std::vector<Value*>& inputs) {
463 return insertGraph(g, callee, inputs);
464 })
465 .def(
466 "insertGraph",
467 [](Graph& g,
468 Graph& callee,
469 const std::vector<Value*>& inputs,
470 std::unordered_map<Value*, Value*> value_map) {
471 return insertGraph(g, callee, inputs, value_map);
472 })
473 .def(
474 "insert",
475 [](Graph& g, Symbol opname, const std::vector<Value*>& args) {
476 std::vector<NamedValue> args_named;
477 args_named.reserve(args.size());
478 for (Value* v : args) {
479 args_named.emplace_back(v);
480 }
481 return g.insert(opname, args_named);
482 })
483 .def(
484 "makeMultiOutputIntoTuple",
485 [](Graph& g) {
486 auto tup = g.createTuple(g.outputs());
487 tup->insertBefore(g.return_node());
488 for (int64_t i = static_cast<int64_t>(g.outputs().size()) - 1;
489 i >= 0;
490 i--) {
491 g.eraseOutput(0);
492 }
493 g.registerOutput(tup->output());
494 })
495 .def(
496 "insertConstant",
497 [](Graph& g, const IValue& ival) { return g.insertConstant(ival); })
498 .GS(lint)
499 .def("block", [](Graph& g) { return g.block(); })
500 .GS(insertNode);
501 #undef GS
502
503 #define VS(name) def(#name, &Value ::name)
504 py::class_<Value, unwrapping_shared_ptr<Value>>(m, "Value")
505 .def(
506 "__repr__",
507 [](Value& n) {
508 std::stringstream ss;
509 ss << n.debugName() << " defined in (" << *n.node() << ")";
510 return ss.str();
511 })
512 .VS(type)
513 .VS(setType)
514 .def(
515 "inferTypeFrom",
516 py::overload_cast<const at::Tensor&>(&Value::inferTypeFrom))
517 .def(
518 "inferTypeFrom",
519 py::overload_cast<const c10::intrusive_ptr<c10::ivalue::Object>&>(
520 &Value::inferTypeFrom))
521 // skip owningGraph because it returns a raw pointer to a otherwise
522 // std::shared_ptr stored graph object, and would cause a double free
523 .VS(unique)
524 .VS(debugName)
525 .VS(setDebugName)
526 .VS(offset)
527 .VS(uses)
528 .VS(replaceAllUsesWith)
529 .VS(replaceAllUsesAfterNodeWith)
530 .def("node", [](Value& v) { return v.node(); })
531 .def(
532 "setTypeAs",
533 [](Value* node, Value* other) {
534 node->setType(other->type());
535 return node;
536 })
537 .VS(copyMetadata)
538 .VS(isCompleteTensor)
539 .VS(requires_grad)
540 .def(
541 "requiresGrad",
542 [](Value& n) {
543 return n.type()->expectRef<TensorType>().requiresGrad();
544 })
545 .def("toIValue", [](Value& n) { return toIValue(&n); })
546 .def("type", [](Value& v) { return v.type(); });
547 #undef VS
548
549 py::class_<Block, unwrapping_shared_ptr<Block>>(m, "Block")
550 .def(
551 "nodes",
552 [](Block& b) {
553 return py::make_iterator(b.nodes().begin(), b.nodes().end());
554 })
555 .def(
556 "findNode",
557 [](Block& b, const std::string& kind, bool recurse) {
558 return findNode(&b, Symbol::fromQualString(kind), recurse);
559 },
560 "Find Node",
561 py::arg("kind"),
562 py::arg("recurse") = true)
563 .def(
564 "findAllNodes",
565 [](Block& b, const std::string& kind, bool recurse) {
566 return findAllNodes(b, Symbol::fromQualString(kind), recurse);
567 },
568 "Find all nodes",
569 py::arg("kind"),
570 py::arg("recurse") = true)
571 .def(
572 "inputs",
573 [](Block& b) {
574 return py::make_iterator(b.inputs().begin(), b.inputs().end());
575 })
576 .def(
577 "outputs",
578 [](Block& b) {
579 return py::make_iterator(b.outputs().begin(), b.outputs().end());
580 })
581 .def("returnNode", [](Block& b) { return b.return_node(); })
582 .def("paramNode", [](Block& b) { return b.param_node(); })
583 .def("owningNode", [](Block& b) { return b.owningNode(); })
584 .def(
585 "addNode",
586 [](Block& b, const char* str, const std::vector<Value*>& inputs) {
587 return addNodeToBlock(&b, Symbol::fromQualString(str), inputs);
588 })
589 .def("addInputToBlock", [](Block& b) { return addInputToBlock(&b); })
590 .def("registerOutput", [](Block& b, Value* value) {
591 return b.registerOutput(value);
592 });
593
594 #define NS(name) def(#name, &Node ::name)
595 py::class_<Node, unwrapping_shared_ptr<Node>>(m, "Node")
596 .def(
597 "__repr__",
598 [](Node& n) {
599 std::stringstream ss;
600 ss << n;
601 return ss.str();
602 })
603 .def("sourceRange", [](Node& n) { return n.sourceRange().str(); })
604 .def("hasMultipleOutputs", [](Node& n) { return n.outputs().size() > 1; })
605 .def("inputsSize", [](Node& n) { return n.inputs().size(); })
606 .def("outputsSize", [](Node& n) { return n.outputs().size(); })
607 .NS(kind)
608 .def("prev", [](Node& n) { return n.prev(); })
609 .def("matches", [](Node& n, const char* s) { return n.matches(s); })
610 .def("owningBlock", [](Node& n) { return n.owningBlock(); })
611 .def("inputsAt", [](Node& n, size_t i) { return n.inputs().at(i); })
612 .def(
613 "inputs",
614 [](Node& n) {
615 return py::make_iterator(n.inputs().begin(), n.inputs().end());
616 })
617 .def(
618 "schema",
619 [](Node& n) {
620 std::stringstream ss;
621 if (n.maybeSchema()) {
622 ss << n.schema();
623 } else {
624 ss << "(no schema)";
625 }
626 return ss.str();
627 })
628 .def(
629 "outputs",
630 [](Node& n) {
631 return py::make_iterator(n.outputs().begin(), n.outputs().end());
632 })
633 .def("outputsAt", [](Node& n, size_t i) { return n.outputs().at(i); })
634 .def(
635 "findNode",
636 [](Node& n, const std::string& kind, bool recurse) {
637 return findNode(n.blocks(), Symbol::fromQualString(kind), recurse);
638 },
639 "Find Node",
640 py::arg("kind"),
641 py::arg("recurse") = true)
642 .def(
643 "findAllNodes",
644 [](Node& n, const std::string& kind, bool recurse) {
645 return findAllNodes(
646 n.blocks(), Symbol::fromQualString(kind), recurse);
647 },
648 "Find all nodes",
649 py::arg("kind"),
650 py::arg("recurse") = true)
651 .def("input", [](Node& n) { return n.input(); })
652 .def("output", [](Node& n) { return n.output(); })
653 .def(
654 "getModuleHierarchy",
655 [](Node& n) { return torch::jit::utils::getNodesModuleHierarchy(n); })
656 .def(
657 "namedInput",
658 [](Node& n, const std::string& unqualName) {
659 return n.namedInput(unqualName);
660 })
661 .NS(addInput)
662 .NS(copyMetadata)
663 .NS(replaceInput)
664 .NS(replaceInputWith)
665 .NS(replaceAllUsesWith)
666 .NS(insertBefore)
667 .NS(insertAfter)
668 .NS(isBefore)
669 .NS(isAfter)
670 .NS(moveAfter)
671 .NS(moveBefore)
672 .NS(removeInput)
673 .NS(removeAllInputs)
674 .NS(destroy)
675 .NS(hasUses)
676 .NS(eraseOutput)
677 .NS(addOutput)
678 .NS(scopeName)
679 .NS(isNondeterministic)
680 .def(
681 "blocks",
682 [](Node& n) {
683 return py::make_iterator(n.blocks().begin(), n.blocks().end());
684 })
685 .NS(addBlock)
686 .NS(mustBeNone)
687
688 #define AS(name) def(#name, &Node::name)
689 // methods from Attributes
690 .AS(copyAttributes)
691 .AS(hasAttributes)
692 #undef AS
693 #define AS(name) def(#name, &Node::name##S)
694 // The default method names take Symbol, but the string conversion for
695 // Symbol you to qualify with attr::. This is not very user friendly
696 // for attributes, so expose the string variants instead.
697 .AS(hasAttribute)
698 .AS(kindOf)
699 .AS(removeAttribute)
700 .AS(attributeNames)
701 #undef AS
702 #define CREATE_ACCESSOR(Kind, method) \
703 def(#method "_", [](Node& n, const char* name, Kind##Attr::ValueType v) { \
704 return n.method##_(Symbol::attr(name), std::move(v)); \
705 }).def(#method, [](Node& n, const char* name) { \
706 return n.method(Symbol::attr(name)); \
707 })
708 .CREATE_ACCESSOR(Float, f)
709 .CREATE_ACCESSOR(Floats, fs)
710 .CREATE_ACCESSOR(Complex, c)
711 .CREATE_ACCESSOR(String, s)
712 .CREATE_ACCESSOR(Strings, ss)
713 .CREATE_ACCESSOR(Int, i)
714 .CREATE_ACCESSOR(Ints, is)
715 .CREATE_ACCESSOR(Graph, g)
716 .CREATE_ACCESSOR(Graphs, gs)
717 .CREATE_ACCESSOR(IValue, ival)
718 #undef CREATE_ACCESSOR
719 // Tensor (t_) -- manually written to unwrap the variable into a tensor.
720 .def(
721 "t_",
722 [](Node& n, const char* name, const torch::autograd::Variable& v) {
723 AT_ASSERT(!v.requires_grad());
724 return n.t_(Symbol::attr(name), v);
725 })
726 .def(
727 "t",
728 [](Node& n, const char* name) { return n.t(Symbol::attr(name)); })
729 // Tensors (ts_) -- manually written to unwrap variables into tensors.
730 .def(
731 "ts_",
732 [](Node& n,
733 const char* name,
734 const std::vector<torch::autograd::Variable>& vs) {
735 std::vector<at::Tensor> tensors;
736 tensors.reserve(vs.size());
737 for (auto& variable : vs) {
738 AT_ASSERT(!variable.requires_grad());
739 tensors.push_back(variable);
740 }
741 return n.ts_(Symbol::attr(name), std::move(tensors));
742 })
743 .def(
744 "ts",
745 [](Node& n, const char* name) {
746 auto tensors = n.ts(Symbol::attr(name));
747 std::vector<torch::autograd::Variable> variables;
748 variables.reserve(tensors.size());
749 for (auto& tensor : tensors) {
750 variables.emplace_back(std::move(tensor));
751 }
752 return variables;
753 })
754 .def(
755 "z_",
756 [](Node& n, const char* name, const at::Tensor& v) {
757 return n.t_(
758 Symbol::attr(name),
759 autograd::Variable(v.view(std::vector<int64_t>{}))
760 .set_requires_grad(false));
761 })
762 .def(
763 "z",
764 [](Node& n, const char* name) { return n.t(Symbol::attr(name)); })
765 .def(
766 "ty_",
767 [](Node& n, const char* name, const TypePtr& type) {
768 return n.ty_(Symbol::attr(name), type);
769 })
770 .def(
771 "ty",
772 [](Node& n, const char* name) { return n.ty(Symbol::attr(name)); })
773 .def(
774 "tys_",
775 [](Node& n, const char* name, const std::vector<TypePtr>& types) {
776 return n.tys_(Symbol::attr(name), types);
777 })
778 .def(
779 "tys",
780 [](Node& n, const char* name) { return n.tys(Symbol::attr(name)); })
781 .def(
782 "zs_",
783 [](Node& n, const char* name, TensorsAttr::ValueType v) {
784 for (auto& i : v) {
785 i = autograd::Variable(i.view(std::vector<int64_t>{}))
786 .set_requires_grad(false);
787 }
788 return n.ts_(Symbol::attr(name), std::move(v));
789 })
790 .def(
791 "zs",
792 [](Node& n, const char* name) { return n.ts(Symbol::attr(name)); })
793 .def(
794 "pyobj",
795 [](Node& n) {
796 return py::handle(n.expect<ConcretePythonOp>()->pyobj.get())
797 .cast<py::object>();
798 })
799 .def("cconv", [](Node& n) { return n.expect<ConcretePythonOp>()->cconv; })
800 .def(
801 "pyname",
802 [](Node& n) { return n.expect<ConcretePythonOp>()->name(); })
803 .def("scalar_args", [](Node& n) {
804 auto op = n.expect<ConcretePythonOp>();
805 auto scalars = py::list();
806 auto append = scalars.attr("append");
807 for (auto& arg : op->scalar_args) {
808 append(py::handle(arg.get()));
809 }
810 return scalars;
811 });
812
813 using ::c10::Type;
814 py::class_<Type, TypePtr>(m, "Type")
815 .def("__repr__", [](Type& t) { return t.annotation_str(); })
816 .def(
817 "str",
818 [](Type& t) {
819 std::ostringstream s;
820 s << t;
821 return s.str();
822 })
823 .def(
824 "containedTypes",
825 [](Type& self) { return self.containedTypes().vec(); })
826 .def("kind", [](const Type& t) { return typeKindToString(t.kind()); })
827 .def(
828 "dim",
829 [](Type& t) {
830 auto vshape = t.expectRef<TensorType>().sizes();
831 return vshape.size() ? py::cast(*vshape.size())
832 : py::cast<py::none>(Py_None);
833 })
834 .def(
835 "undefined",
836 [](Type& t) {
837 auto undef = t.expectRef<TensorType>().undefined();
838 return undef.has_value() ? py::cast(*undef)
839 : py::cast<py::none>(Py_None);
840 })
841 .def(
842 "sizes",
843 [](Type& t) -> py::object {
844 if (auto ptt = t.expect<TensorType>()) {
845 if (auto cs = ptt->sizes().concrete_sizes()) {
846 return py::cast(*cs);
847 }
848 }
849 return py::none();
850 })
851 .def(
852 "symbolic_sizes",
853 [](Type& t) -> py::object {
854 if (auto ptt = t.expect<TensorType>()) {
855 auto ss = ptt->symbolic_sizes();
856 if (!ss.rank().has_value()) {
857 return py::none();
858 }
859
860 std::vector<int64_t> ss_vals;
861 for (size_t i = 0; i < *ss.rank(); ++i) {
862 ss_vals.push_back(ss.at(i).value());
863 }
864 return py::cast(ss_vals);
865 }
866 return py::none();
867 })
868 .def(
869 "with_sizes",
870 [](Type& t, std::optional<std::vector<std::optional<int64_t>>> sizes)
871 -> py::object {
872 auto ptt = t.expect<TensorType>();
873 if (!ptt) {
874 return py::none();
875 }
876 if (!sizes) {
877 return py::cast(ptt->withSymbolicShapes(c10::SymbolicShape()));
878 }
879 return py::cast(ptt->withSymbolicShapes(*sizes));
880 })
881 .def(
882 "varyingSizes",
883 [](Type& t) -> py::object {
884 if (auto ptt = t.expect<TensorType>()) {
885 if (auto s = ptt->sizes().sizes()) {
886 return py::cast(s.value());
887 }
888 }
889 return py::none();
890 })
891 .def(
892 "strides",
893 [](Type& t) -> py::object {
894 if (auto ptt = t.expect<TensorType>()) {
895 if (auto cs = ptt->strides().concrete_sizes()) {
896 return py::cast(*cs);
897 }
898 }
899 return py::none();
900 })
901 .def(
902 "contiguous",
903 [](Type& t) {
904 return std::static_pointer_cast<Type>(
905 t.expectRef<TensorType>().contiguous());
906 })
907 .def(
908 "scalarType",
909 [](Type& t) {
910 auto scalar_type = t.expectRef<TensorType>().scalarType();
911 return (scalar_type) ? toString(*scalar_type) : nullptr;
912 })
913 .def(
914 "device",
915 [](Type& t) -> py::object {
916 auto device = t.expectRef<TensorType>().device();
917 if (!device) {
918 return py::none();
919 }
920 PyObject* thp_device = THPDevice_New(device.value());
921 return py::reinterpret_borrow<py::object>(thp_device);
922 // return toPyObject(device.value());
923 })
924 .def(
925 "with_device",
926 [](Type& t, py::object device) -> py::object {
927 at::Device c_device =
928 python::detail::py_object_to_device(std::move(device));
929 if (auto ptt = t.expect<TensorType>()) {
930 return py::cast(ptt->withDevice(c_device));
931 }
932 return py::none();
933 })
934 .def(
935 "dtype",
936 [](Type& t) -> py::object {
937 auto scalar_type = t.expectRef<TensorType>().scalarType();
938 if (!scalar_type) {
939 return py::none();
940 }
941 THPDtype* thp_dtype = torch::getTHPDtype(*scalar_type);
942 py::object dtype =
943 py::reinterpret_borrow<py::object>((PyObject*)thp_dtype);
944 return dtype;
945 })
946 .def(
947 "with_dtype",
948 [](Type& t, py::object dtype) -> py::object {
949 at::ScalarType scalar_type =
950 python::detail::py_object_to_dtype(std::move(dtype));
951
952 if (auto ptt = t.expect<TensorType>()) {
953 // auto scalar_type = dtype->scalar_type;
954 return py::cast(ptt->withScalarType(scalar_type));
955 }
956 return py::none();
957 })
958 .def(
959 "__eq__",
960 [](const TypePtr& self, const TypePtr& other) {
961 if (!other) {
962 return false;
963 }
964 return *self == *other;
965 })
966 .def(
967 "isSubtypeOf",
968 [](const TypePtr& self, const TypePtr& other) {
969 if (!other) {
970 return false;
971 }
972 return self->isSubtypeOf(other);
973 })
974 .def(
975 "is_interface_type",
976 [](const TypePtr& self) {
977 return self->castRaw<InterfaceType>() != nullptr;
978 })
979 .def(
980 "requires_grad",
981 [](const TypePtr& self) -> bool { return self->requires_grad(); })
982 .def_property_readonly(
983 "annotation_str", [](const std::shared_ptr<Type>& self) {
984 return self->annotation_str();
985 });
986
987 py::class_<AnyType, Type, AnyTypePtr>(m, "AnyType")
988 .def_static("get", &AnyType::get);
989 py::class_<NumberType, Type, NumberTypePtr>(m, "NumberType")
990 .def_static("get", &NumberType::get);
991 py::class_<IntType, Type, IntTypePtr>(m, "IntType")
992 .def_static("get", &IntType::get);
993 py::class_<SymIntType, Type, SymIntTypePtr>(m, "SymIntType")
994 .def_static("get", &SymIntType::get);
995 py::class_<SymBoolType, Type, SymBoolTypePtr>(m, "SymBoolType")
996 .def_static("get", &SymBoolType::get);
997 py::class_<FloatType, Type, FloatTypePtr>(m, "FloatType")
998 .def_static("get", &FloatType::get);
999 py::class_<ComplexType, Type, ComplexTypePtr>(m, "ComplexType")
1000 .def_static("get", &ComplexType::get);
1001 py::class_<TensorType, Type, TensorTypePtr>(m, "TensorType")
1002 .def_static("get", &TensorType::get)
1003 .def_static("getInferred", &TensorType::getInferred)
1004 .def_static("create_from_tensor", [](const at::Tensor& t) {
1005 return TensorType::create(t);
1006 });
1007 py::class_<BoolType, Type, BoolTypePtr>(m, "BoolType")
1008 .def_static("get", &BoolType::get);
1009 py::class_<StringType, Type, StringTypePtr>(m, "StringType")
1010 .def_static("get", &StringType::get);
1011 py::class_<DeviceObjType, Type, DeviceObjTypePtr>(m, "DeviceObjType")
1012 .def_static("get", &DeviceObjType::get);
1013 // TODO(antoniojkim): Add GeneratorType to the public API once its been added
1014 // to the public documentation
1015 py::class_<GeneratorType, Type, GeneratorTypePtr>(m, "_GeneratorType")
1016 .def_static("get", &GeneratorType::get);
1017 py::class_<StreamObjType, Type, StreamObjTypePtr>(m, "StreamObjType")
1018 .def_static("get", &StreamObjType::get);
1019 py::class_<PyObjectType, Type, PyObjectTypePtr>(m, "PyObjectType")
1020 .def_static("get", &PyObjectType::get);
1021 py::class_<NoneType, Type, NoneTypePtr>(m, "NoneType")
1022 .def_static("get", &NoneType::get);
1023
1024 py::class_<TupleType, Type, TupleTypePtr>(m, "TupleType")
1025 .def(py::init([](std::vector<TypePtr> a) {
1026 return TupleType::create(std::move(a));
1027 }))
1028 .def("elements", [](TupleType& self) {
1029 std::vector<TypePtr> types;
1030 for (const auto& type : self.elements()) {
1031 types.push_back(type);
1032 }
1033 return types;
1034 });
1035 py::class_<UnionType, Type, UnionTypePtr>(m, "UnionType")
1036 .def(py::init(
1037 [](const std::vector<TypePtr>& a) { return UnionType::create(a); }));
1038 py::class_<ListType, Type, ListTypePtr>(m, "ListType")
1039 .def(py::init([](const TypePtr& a) { return ListType::create(a); }))
1040 .def_static("ofInts", &ListType::ofInts)
1041 .def_static("ofTensors", &ListType::ofTensors)
1042 .def_static("ofFloats", &ListType::ofFloats)
1043 .def_static("ofComplexDoubles", &ListType::ofComplexDoubles)
1044 .def_static("ofBools", &ListType::ofBools)
1045 .def_static("ofStrings", &ListType::ofStrings)
1046 .def("getElementType", &ListType::getElementType);
1047 py::class_<DictType, Type, DictTypePtr>(m, "DictType")
1048 .def(py::init([](TypePtr key, TypePtr value) {
1049 return DictType::create(std::move(key), std::move(value));
1050 }))
1051 .def("getKeyType", &DictType::getKeyType)
1052 .def("getValueType", &DictType::getValueType);
1053 py::class_<OptionalType, Type, OptionalTypePtr>(m, "OptionalType")
1054 .def(py::init([](const TypePtr& a) { return OptionalType::create(a); }))
1055 .def_static("ofTensor", &OptionalType::ofTensor)
1056 .def("getElementType", &OptionalType::getElementType);
1057 py::class_<RRefType, Type, RRefTypePtr>(m, "RRefType")
1058 .def(py::init([](TypePtr a) { return RRefType::create(std::move(a)); }))
1059 .def("getElementType", &RRefType::getElementType);
1060
1061 py::class_<FutureType, Type, FutureTypePtr>(m, "FutureType")
1062 .def(py::init([](TypePtr a) { return FutureType::create(std::move(a)); }))
1063 .def("getElementType", &FutureType::getElementType);
1064
1065 py::class_<AwaitType, Type, AwaitTypePtr>(m, "AwaitType")
1066 .def(py::init([](TypePtr a) { return AwaitType::create(std::move(a)); }))
1067 .def("getElementType", &AwaitType::getElementType);
1068
1069 py::class_<ClassType, Type, ClassTypePtr>(m, "ClassType")
1070 .def(py::init([](const std::string& qualified_name) {
1071 return get_python_cu()->get_class(c10::QualifiedName(qualified_name));
1072 }))
1073 .def("name", [](ClassType& self) { return self.name()->name(); })
1074 .def("qualified_name", [](ClassType& self) {
1075 return self.name()->qualifiedName();
1076 });
1077 py::class_<EnumType, Type, EnumTypePtr>(m, "EnumType")
1078 .def(py::init([](const std::string& qualified_name,
1079 TypePtr value_type,
1080 const std::vector<py::object>& enum_names_values) {
1081 std::vector<std::pair<std::string, IValue>> names_values;
1082 names_values.reserve(enum_names_values.size());
1083 for (const auto& enum_name_value : enum_names_values) {
1084 auto enum_name = py::cast<std::string>(enum_name_value.attr("name"));
1085 auto enum_value = toIValue(enum_name_value.attr("value"), value_type);
1086 names_values.emplace_back(enum_name, enum_value);
1087 }
1088 return EnumType::create(
1089 c10::QualifiedName(qualified_name),
1090 std::move(value_type),
1091 std::move(names_values),
1092 get_python_cu());
1093 }));
1094 py::class_<InterfaceType, Type, InterfaceTypePtr>(m, "InterfaceType")
1095 .def(py::init([](const std::string& qualified_name) {
1096 return get_python_cu()->get_interface(
1097 c10::QualifiedName(qualified_name));
1098 }))
1099 .def(
1100 "getMethod",
1101 [](InterfaceType& self, const std::string& name) {
1102 return self.getMethod(name);
1103 },
1104 py::return_value_policy::reference)
1105 .def("getMethodNames", [](InterfaceType& self) {
1106 std::vector<std::string> names;
1107 for (const FunctionSchema& fn : self.methods()) {
1108 names.emplace_back(fn.name());
1109 }
1110 return names;
1111 });
1112 using ::c10::InferredType;
1113 py::class_<InferredType, std::shared_ptr<InferredType>>(m, "InferredType")
1114 .def(py::init([](std::shared_ptr<Type> type) {
1115 return std::make_shared<InferredType>(std::move(type));
1116 }))
1117 .def(py::init([](std::string reason) {
1118 return std::make_shared<InferredType>(std::move(reason));
1119 }))
1120 .def(
1121 "type",
1122 [](const std::shared_ptr<InferredType>& self) {
1123 return self->type();
1124 })
1125 .def(
1126 "success",
1127 [](const std::shared_ptr<InferredType>& self) {
1128 return self->success();
1129 })
1130 .def("reason", [](const std::shared_ptr<InferredType>& self) {
1131 return self->reason();
1132 });
1133
1134 py::class_<Use>(m, "Use")
1135 .def_readonly("user", &Use::user)
1136 .def_readonly("offset", &Use::offset)
1137 .def("isAfter", [](Use& self, Use& other_use) {
1138 return isBeforeOrAfter(self, other_use, false);
1139 });
1140
1141 py::class_<torch::jit::ShapeComputeGraphMapping>(
1142 m, "_ShapeComputeGraphMapping")
1143 .def(
1144 "partial_eval_shape_graph",
1145 [](ShapeComputeGraphMapping& g) {
1146 return g.partial_eval_shape_graph;
1147 })
1148 .def(
1149 "graph_output_to_symbolic_shape_dim",
1150 [](ShapeComputeGraphMapping& g) {
1151 return g.graph_output_to_symbolic_shape_dim_;
1152 });
1153 }
1154 } // namespace torch::jit
1155