#include <torch/csrc/jit/python/python_ir.h>

#include <ATen/core/jit_type.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/api/include/torch/python.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/python/pybind.h>
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/python_print.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_strings.h>
#include <iostream>
#include <sstream>
#include <utility>

namespace torch::jit {

// Controls whether graph source ranges are printed by default
bool global_print_source_ranges = true;

Symbol ConcretePythonOp::Kind = prim::PythonOp;

using c10::Type;

std::string getPythonName(const PyObject* obj_) {
  pybind11::gil_scoped_acquire gil;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  PyObject* obj = const_cast<PyObject*>(obj_);
  auto v = py::getattr(obj, "__name__", py::str("<python_value>"));
  // if this was an autograd.Function, recover the name of the class
  return py::str(v);
}

std::ostream& printPyObject(std::ostream& out, const THPObjectPtr& obj) {
  pybind11::gil_scoped_acquire gil;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  auto pyobj = py::handle(const_cast<PyObject*>(obj.get()));
  if (py::isinstance<py::tuple>(pyobj)) {
    // This special-case for printing tuples handles a problem where
    // str((2L, 3L)) outputs "(2L, 3L)" in Python 2 but "(2, 3)"
    // in Python 3.  In order to suppress the L-suffix, we must
    // manually print the string ourselves, calling str() on the
    // sub-elements.
    //
    // This is a fairly fragile fix (What if you have nested tuples
    // in tuples? What if you have dictionaries?) but it seems to hit
    // the cases that are triggered in practice in onnx-pytorch.  Revisit
    // this code if this is not the case.
    //
    // By the way, one non-solution for this problem is to monkeypatch
    // tuple.__str__; this doesn't work because Python doesn't allow
    // monkeypatching methods of built-in types.
    auto pytuple = pyobj.cast<py::tuple>();
    out << "(";
    size_t i = 0;
    for (const auto& o : pytuple) {
      if (i > 0) {
        out << ", ";
      }
      THPObjectPtr str(py::str(o).release().ptr());
      out << THPUtils_unpackString(str.get());
      i++;
    }
    if (i == 1) {
      out << ",";
    }
    out << ")";
    return out;
  } else {
    return out << THPUtils_unpackString(py::str(pyobj).ptr());
  }
}

Node* findNode(
    c10::ArrayRef<torch::jit::Block*> blocks,
    Symbol kind,
    bool recurse = true) {
  for (Block* block : blocks) {
    for (Node* n : block->nodes()) {
      if (n->kind() == kind) {
        return n;
      }
      if (recurse) {
        auto node = findNode(n->blocks(), kind, recurse);
        if (node != nullptr) {
          return node;
        }
      }
    }
  }
  return nullptr;
}

Node* findNode(Block* block, Symbol kind, bool recurse = true) {
  std::vector<Block*> blocks = {block};
  return findNode(blocks, kind, recurse);
}
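
// These overloads back the Python-side findNode bindings registered below on
// Graph, Block, and Node. Illustrative Python usage (the op kind string is
// just an example): graph.findNode("prim::Constant", recurse=True).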

std::string ConcretePythonOp::name() const {
  pybind11::gil_scoped_acquire gil;
  if (auto autograd = autogradFunction()) {
    return getPythonName(autograd->get());
  } else {
    return getPythonName(pyobj.get());
  }
}

void ConcretePythonOp::cloneFrom(Node* other_) {
  // NOLINTNEXTLINE(bugprone-parent-virtual-call)
  Node::cloneFrom(other_);
  auto other = other_->cast<ConcretePythonOp>();
  this->cconv = other->cconv;
  Py_INCREF(other->pyobj.get());
  this->pyobj = THPObjectPtr(other->pyobj.get());
  for (auto& sa : other->scalar_args) {
    Py_INCREF(sa.get());
    this->scalar_args.emplace_back(sa.get());
  }
}

// Recover the autograd.Function instance if this PythonOp's function was
// originally SomeFunction.apply; used in ONNX for discovering symbolics.
std::optional<THPObjectPtr> ConcretePythonOp::autogradFunction() const {
  pybind11::gil_scoped_acquire gil;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  py::handle obj = const_cast<PyObject*>(pyobj.get());

  auto r = py::getattr(obj, "__self__", py::none());
  if (r.is_none())
    return std::nullopt;

  auto apply = py::getattr(r, "apply", py::none());
  if (apply.is_none())
    return std::nullopt;

  auto c = PyObject_RichCompareBool(apply.ptr(), obj.ptr(), Py_NE);
  if (PyErr_Occurred())
    throw py::error_already_set();
  if (c)
    return std::nullopt;

  return THPObjectPtr(r.release().ptr());
}
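
// Restating the check above: if pyobj exposes a `__self__` whose `apply`
// attribute is pyobj itself (i.e. pyobj is SomeFunction.apply), that
// `__self__` is the autograd.Function and is returned; otherwise nullopt.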

void ConcretePythonOp::writeScalars(std::ostream& out) const {
  out << "(";
  int i = 0;
  for (auto& scalar : scalar_args) {
    if (i++ > 0)
      out << ", ";
    printPyObject(out, scalar);
  }
  out << ")";
}

void ConcretePythonOp::lint_python() const {
  size_t n_scalars = 0, n_tensors = 0;
  for (auto c : cconv) {
    if (c == 'c') {
      n_scalars++;
    } else if (c == 'd') {
      n_tensors++;
    } else {
      AT_ASSERT(0);
    }
    AT_ASSERT(static_cast<bool>(pyobj));
  }
  AT_ASSERT(n_scalars == scalar_args.size());
  AT_ASSERT(n_tensors == inputs().size());
}

Node* Graph::createPythonOp(
    THPObjectPtr&& pyobj,
    const std::string& cconv,
    pyobj_list&& scalar_args) {
  ConcretePythonOp* op = new ConcretePythonOp(this);
  return op->init(std::move(pyobj), cconv, std::move(scalar_args));
}

void initPythonIRBindings(PyObject* module_) {
  auto m = py::handle(module_).cast<py::module>();

  py::class_<AliasDb, std::shared_ptr<AliasDb>>(m, "AliasDb")
      .def("dump", &AliasDb::dump)
      .def("to_graphviz_str", &AliasDb::toGraphviz)
      .def(
          "may_contain_alias",
          [&](AliasDb& db, Value* v1, Value* v2) {
            return db.mayContainAlias(v1, v2);
          })
      .def(
          "has_writers",
          [&](AliasDb& db, Value* v1) { return db.hasWriters(v1); })
      .def("__str__", &AliasDb::toString)
      .def(
          "move_after_topologically_valid",
          [](AliasDb& db, Node* n, Node* movePoint) {
            return db.moveAfterTopologicallyValid(n, movePoint);
          })
      .def(
          "move_before_topologically_valid",
          [](AliasDb& db, Node* n, Node* movePoint) {
            return db.moveBeforeTopologicallyValid(n, movePoint);
          });
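  // Illustrative Python usage of the AliasDb bindings above (names as
  // registered here): `db = graph.alias_db(); db.may_contain_alias(v1, v2);
  // db.has_writers(v1); print(db)`. The alias_db factory itself is bound on
  // Graph below.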
#define GS(name) def(#name, &Graph ::name)
  py::class_<Graph, std::shared_ptr<Graph>>(m, "Graph")
      .def(py::init<>())
      .def(
          "__repr__",
          [&](Graph& g) { return g.toString(global_print_source_ranges); })
      .def("str", &Graph::toString, py::arg("print_source_ranges") = true)
      .def_readonly_static(
          "global_print_source_ranges", &global_print_source_ranges)
      .def_static(
          "set_global_print_source_ranges",
          [&](const bool enabled) { global_print_source_ranges = enabled; },
          py::arg("enabled") = true)
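      // Illustrative: Graph.set_global_print_source_ranges(False) drops
      // source ranges from __repr__ output for all graphs, while
      // Graph.str(print_source_ranges=...) controls it per call.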
      .def(
          "alias_db",
          [](std::shared_ptr<Graph> g,
             bool isFrozen = false,
             bool descend_function_calls = false) {
            return std::make_shared<AliasDb>(
                std::move(g), isFrozen, descend_function_calls);
          },
          py::arg("isFrozen") = false,
          py::arg("descend_function_calls") = false)
      .def(
          "dump_alias_db",
          [](std::shared_ptr<Graph> g) {
            AliasDb db(std::move(g));
            db.dump();
          })
      .def(
          "_export_onnx",
          [](const std::shared_ptr<Graph>& g,
             const std::map<std::string, at::Tensor>& initializers,
             int64_t onnx_opset_version,
             const std::unordered_map<
                 std::string,
                 std::unordered_map<int64_t, std::string>>& dynamic_axes,
             bool defer_weight_export,
             ::torch::onnx::OperatorExportTypes operator_export_type,
             bool strip_doc_string,
             bool keep_initializers_as_inputs,
             const std::map<std::string, int>& custom_opsets,
             bool add_node_names,
             const std::string& onnx_file_path,
             const NodeAttrNameMap& node_attr_to_name) {
            std::string graph;
            auto
                [model_proto,
                 export_map,
                 symbol_map,
                 val_use_external_data_format,
                 onnx_node_names] =
                    export_onnx(
                        g,
                        initializers,
                        onnx_opset_version,
                        dynamic_axes,
                        defer_weight_export,
                        operator_export_type,
                        strip_doc_string,
                        keep_initializers_as_inputs,
                        custom_opsets,
                        add_node_names,
                        false,
                        onnx_file_path,
                        node_attr_to_name);
            std::unordered_map<std::string, py::bytes>
                python_serialized_export_map;
            for (auto& kv : export_map) {
              auto t = kv.second;
              size_t copy_bytes = t.element_size() * t.numel();
              // TODO: this is an unnecessary copy. In theory we can directly
              // return the map from identifier to Tensor, but we need some API
              // in Python to get raw `bytes` containing the raw tensor data.
              python_serialized_export_map[kv.first] =
                  py::bytes(static_cast<const char*>(t.data_ptr()), copy_bytes);
            }
            graph = serialize_model_proto_to_string(model_proto);
            return std::make_tuple(
                py::bytes(graph),
                python_serialized_export_map,
                val_use_external_data_format,
                onnx_node_names);
          },
          py::arg("initializers"),
          py::arg("onnx_opset_version") = 0,
          py::arg("dynamic_axes"),
          py::arg("defer_weight_export") = false,
          py::arg("operator_export_type") =
              ::torch::onnx::OperatorExportTypes::ONNX,
          py::arg("strip_doc_string") = true,
          py::arg("keep_initializers_as_inputs") = true,
          py::arg("custom_opsets"),
          py::arg("add_node_names") = true,
          py::arg("onnx_file_path") = std::string(),
          py::arg("node_attr_to_name") = NodeAttrNameMap())
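      // _export_onnx returns a 4-tuple: the serialized ModelProto bytes, the
      // identifier-to-raw-bytes map built above, val_use_external_data_format,
      // and the ONNX node names assigned during export.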
      .def(
          "_pretty_print_onnx",
          [](const std::shared_ptr<Graph>& g,
             const std::map<std::string, at::Tensor>& initializers,
             int64_t onnx_opset_version,
             bool defer_weight_export,
             ::torch::onnx::OperatorExportTypes operator_export_type,
             bool google_printer,
             bool keep_initializers_as_inputs,
             const std::map<std::string, int>& custom_opsets,
             bool add_node_names) {
            return pretty_print_onnx(
                g,
                initializers,
                onnx_opset_version,
                defer_weight_export,
                operator_export_type,
                google_printer,
                keep_initializers_as_inputs,
                custom_opsets,
                add_node_names);
          },
          py::arg("initializers"),
          py::arg("onnx_opset_version") = 0,
          py::arg("defer_weight_export") = false,
          py::arg("operator_export_type") =
              ::torch::onnx::OperatorExportTypes::ONNX,
          py::arg("google_printer") = false,
          py::arg("keep_initializers_as_inputs") = true,
          py::arg("custom_opsets"),
          py::arg("add_node_names") = true)
      .def(
          "inputs",
          [](Graph& g) {
            return py::make_iterator(g.inputs().begin(), g.inputs().end());
          },
          py::keep_alive<0, 1>())
      .def(
          "outputs",
          [](Graph& g) {
            return py::make_iterator(g.outputs().begin(), g.outputs().end());
          },
          py::keep_alive<0, 1>())
      // We keep the graph alive while the iterator lives. Destroying
      // nodes might still be hazardous.
      .def(
          "nodes",
          [](Graph& g) {
            return py::make_iterator(g.nodes().begin(), g.nodes().end());
          },
          py::keep_alive<0, 1>())
      .def(
          "findNode",
          [](Graph& g, const std::string& kind, bool recurse) {
            return findNode(g.block(), Symbol::fromQualString(kind), recurse);
          },
          "Find Node",
          py::arg("kind"),
          py::arg("recurse") = true)
      .def(
          "findAllNodes",
          [](Graph& g, const std::string& kind, bool recurse) {
            return findAllNodes(g, Symbol::fromQualString(kind), recurse);
          },
          "Find all nodes",
          py::arg("kind"),
          py::arg("recurse") = true)
      .def(
          "addInput",
          [](Graph& g, const std::string& name) { return g.addInput(name); },
          "Add input to graph with optional name seed",
          py::arg("name") = "")
      .def("copy", [](Graph& g) { return g.copy(); })
      .GS(eraseInput)
      .GS(eraseOutput)
      .GS(registerOutput)
      .def(
          "permuteInputs",
          [](Graph& g, const std::vector<size_t>& new_inputs) {
            g.block()->permuteInputs(new_inputs);
          })
      .def(
          "create",
          [](Graph& g, const char* str) {
            return g.create(Symbol::fromQualString(str));
          })
      .def(
          "create",
          [](Graph& g, const char* str, size_t noutputs) {
            return g.create(Symbol::fromQualString(str), noutputs);
          })
      .def(
          "create",
          [](Graph& g, const char* str, const std::vector<Value*>& inputs) {
            TORCH_CHECK_VALUE(
                std::all_of(
                    inputs.begin(),
                    inputs.end(),
                    [](Value* v) { return (v != nullptr); }),
                "cannot pass None in inputs");
            return g.create(Symbol::fromQualString(str), inputs);
          })
      .def(
          "create",
          [](Graph& g,
             const char* str,
             const std::vector<Value*>& inputs,
             size_t noutputs) {
            TORCH_CHECK_VALUE(
                std::all_of(
                    inputs.begin(),
                    inputs.end(),
                    [](Value* v) { return (v != nullptr); }),
                "cannot pass None in inputs");
            return g.create(Symbol::fromQualString(str), inputs, noutputs);
          })
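      // Illustrative graph construction from Python via the bindings above
      // (the op kind string is only an example):
      //   g = torch._C.Graph()
      //   n = g.create("prim::Constant")
      //   g.insertNode(n)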
      .def("param_node", [](Graph& g) { return g.block()->param_node(); })
      .def("return_node", [](Graph& g) { return g.block()->return_node(); })
      .def(
          "createFusionGroup",
          [](Graph& g) { return g.createWithSubgraph(prim::FusionGroup); })
      .def(
          "createCudaFusionGroup",
          [](Graph& g) { return g.createWithSubgraph(prim::CudaFusionGroup); })
      .def(
          "createClone",
          [](Graph& g, Node* n, py::object fn) {
            return g.createClone(
                n, [&](Value* e) { return fn(e).cast<Value*>(); });
          })
      .GS(appendNode)
      .GS(prependNode)
      // NB: insert_point_guard is defined on top of direct modification of the
      // insert point (it delegates to torch.jit._ir_utils)
      .def(
          "insert_point_guard",
          [](Graph& g, Node* n) {
            return py::module::import("torch.jit._ir_utils")
                .attr("insert_point_guard")(g, n);
          })
      .def(
          "insert_point_guard",
          [](Graph& g, Block* b) {
            return py::module::import("torch.jit._ir_utils")
                .attr("insert_point_guard")(g, b);
          })
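      // Illustrative Python usage (assuming the torch.jit._ir_utils helper is
      // a context manager):
      //   with graph.insert_point_guard(node):
      //       ...  # inserts land before `node`; prior point restored on exit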
      .GS(insertPoint)
      .def("setInsertPoint", [](Graph& g, Node* n) { g.setInsertPoint(n); })
      .def("setInsertPoint", [](Graph& g, Block* n) { g.setInsertPoint(n); })
      .def(
          "insertGraph",
          [](Graph& g, Graph& callee, const std::vector<Value*>& inputs) {
            return insertGraph(g, callee, inputs);
          })
      .def(
          "insertGraph",
          [](Graph& g,
             Graph& callee,
             const std::vector<Value*>& inputs,
             std::unordered_map<Value*, Value*> value_map) {
            return insertGraph(g, callee, inputs, value_map);
          })
      .def(
          "insert",
          [](Graph& g, Symbol opname, const std::vector<Value*>& args) {
            std::vector<NamedValue> args_named;
            args_named.reserve(args.size());
            for (Value* v : args) {
              args_named.emplace_back(v);
            }
            return g.insert(opname, args_named);
          })
      .def(
          "makeMultiOutputIntoTuple",
          [](Graph& g) {
            auto tup = g.createTuple(g.outputs());
            tup->insertBefore(g.return_node());
            for (int64_t i = static_cast<int64_t>(g.outputs().size()) - 1;
                 i >= 0;
                 i--) {
              g.eraseOutput(0);
            }
            g.registerOutput(tup->output());
          })
      .def(
          "insertConstant",
          [](Graph& g, const IValue& ival) { return g.insertConstant(ival); })
      .GS(lint)
      .def("block", [](Graph& g) { return g.block(); })
      .GS(insertNode);
#undef GS

#define VS(name) def(#name, &Value ::name)
  py::class_<Value, unwrapping_shared_ptr<Value>>(m, "Value")
      .def(
          "__repr__",
          [](Value& n) {
            std::stringstream ss;
            ss << n.debugName() << " defined in (" << *n.node() << ")";
            return ss.str();
          })
      .VS(type)
      .VS(setType)
      .def(
          "inferTypeFrom",
          py::overload_cast<const at::Tensor&>(&Value::inferTypeFrom))
      .def(
          "inferTypeFrom",
          py::overload_cast<const c10::intrusive_ptr<c10::ivalue::Object>&>(
              &Value::inferTypeFrom))
      // skip owningGraph because it returns a raw pointer to an otherwise
      // shared_ptr-managed graph object, and would cause a double free
      .VS(unique)
      .VS(debugName)
      .VS(setDebugName)
      .VS(offset)
      .VS(uses)
      .VS(replaceAllUsesWith)
      .VS(replaceAllUsesAfterNodeWith)
      .def("node", [](Value& v) { return v.node(); })
      .def(
          "setTypeAs",
          [](Value* node, Value* other) {
            node->setType(other->type());
            return node;
          })
      .VS(copyMetadata)
      .VS(isCompleteTensor)
      .VS(requires_grad)
      .def(
          "requiresGrad",
          [](Value& n) {
            return n.type()->expectRef<TensorType>().requiresGrad();
          })
      .def("toIValue", [](Value& n) { return toIValue(&n); })
      .def("type", [](Value& v) { return v.type(); });
#undef VS
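
  // Illustrative Python usage of the Value bindings above:
  //   v.debugName(), v.unique(), v.type()
  //   v.setTypeAs(other)            # copies other's type onto v
  //   v.replaceAllUsesWith(other)   # rewires every use of v to other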

  py::class_<Block, unwrapping_shared_ptr<Block>>(m, "Block")
      .def(
          "nodes",
          [](Block& b) {
            return py::make_iterator(b.nodes().begin(), b.nodes().end());
          })
      .def(
          "findNode",
          [](Block& b, const std::string& kind, bool recurse) {
            return findNode(&b, Symbol::fromQualString(kind), recurse);
          },
          "Find Node",
          py::arg("kind"),
          py::arg("recurse") = true)
      .def(
          "findAllNodes",
          [](Block& b, const std::string& kind, bool recurse) {
            return findAllNodes(b, Symbol::fromQualString(kind), recurse);
          },
          "Find all nodes",
          py::arg("kind"),
          py::arg("recurse") = true)
      .def(
          "inputs",
          [](Block& b) {
            return py::make_iterator(b.inputs().begin(), b.inputs().end());
          })
      .def(
          "outputs",
          [](Block& b) {
            return py::make_iterator(b.outputs().begin(), b.outputs().end());
          })
      .def("returnNode", [](Block& b) { return b.return_node(); })
      .def("paramNode", [](Block& b) { return b.param_node(); })
      .def("owningNode", [](Block& b) { return b.owningNode(); })
      .def(
          "addNode",
          [](Block& b, const char* str, const std::vector<Value*>& inputs) {
            return addNodeToBlock(&b, Symbol::fromQualString(str), inputs);
          })
      .def("addInputToBlock", [](Block& b) { return addInputToBlock(&b); })
      .def("registerOutput", [](Block& b, Value* value) {
        return b.registerOutput(value);
      });

#define NS(name) def(#name, &Node ::name)
  py::class_<Node, unwrapping_shared_ptr<Node>>(m, "Node")
      .def(
          "__repr__",
          [](Node& n) {
            std::stringstream ss;
            ss << n;
            return ss.str();
          })
      .def("sourceRange", [](Node& n) { return n.sourceRange().str(); })
      .def("hasMultipleOutputs", [](Node& n) { return n.outputs().size() > 1; })
      .def("inputsSize", [](Node& n) { return n.inputs().size(); })
      .def("outputsSize", [](Node& n) { return n.outputs().size(); })
      .NS(kind)
      .def("prev", [](Node& n) { return n.prev(); })
      .def("matches", [](Node& n, const char* s) { return n.matches(s); })
      .def("owningBlock", [](Node& n) { return n.owningBlock(); })
      .def("inputsAt", [](Node& n, size_t i) { return n.inputs().at(i); })
      .def(
          "inputs",
          [](Node& n) {
            return py::make_iterator(n.inputs().begin(), n.inputs().end());
          })
      .def(
          "schema",
          [](Node& n) {
            std::stringstream ss;
            if (n.maybeSchema()) {
              ss << n.schema();
            } else {
              ss << "(no schema)";
            }
            return ss.str();
          })
      .def(
          "outputs",
          [](Node& n) {
            return py::make_iterator(n.outputs().begin(), n.outputs().end());
          })
      .def("outputsAt", [](Node& n, size_t i) { return n.outputs().at(i); })
      .def(
          "findNode",
          [](Node& n, const std::string& kind, bool recurse) {
            return findNode(n.blocks(), Symbol::fromQualString(kind), recurse);
          },
          "Find Node",
          py::arg("kind"),
          py::arg("recurse") = true)
      .def(
          "findAllNodes",
          [](Node& n, const std::string& kind, bool recurse) {
            return findAllNodes(
                n.blocks(), Symbol::fromQualString(kind), recurse);
          },
          "Find all nodes",
          py::arg("kind"),
          py::arg("recurse") = true)
      .def("input", [](Node& n) { return n.input(); })
      .def("output", [](Node& n) { return n.output(); })
      .def(
          "getModuleHierarchy",
          [](Node& n) { return torch::jit::utils::getNodesModuleHierarchy(n); })
      .def(
          "namedInput",
          [](Node& n, const std::string& unqualName) {
            return n.namedInput(unqualName);
          })
      .NS(addInput)
      .NS(copyMetadata)
      .NS(replaceInput)
      .NS(replaceInputWith)
      .NS(replaceAllUsesWith)
      .NS(insertBefore)
      .NS(insertAfter)
      .NS(isBefore)
      .NS(isAfter)
      .NS(moveAfter)
      .NS(moveBefore)
      .NS(removeInput)
      .NS(removeAllInputs)
      .NS(destroy)
      .NS(hasUses)
      .NS(eraseOutput)
      .NS(addOutput)
      .NS(scopeName)
      .NS(isNondeterministic)
      .def(
          "blocks",
          [](Node& n) {
            return py::make_iterator(n.blocks().begin(), n.blocks().end());
          })
      .NS(addBlock)
      .NS(mustBeNone)

#define AS(name) def(#name, &Node::name)
      // methods from Attributes
      .AS(copyAttributes)
      .AS(hasAttributes)
#undef AS
#define AS(name) def(#name, &Node::name##S)
      // The default method names take Symbol, but the string conversion for
      // Symbol requires you to qualify with attr::. This is not very
      // user-friendly for attributes, so expose the string variants instead.
      .AS(hasAttribute)
      .AS(kindOf)
      .AS(removeAttribute)
      .AS(attributeNames)
#undef AS
#define CREATE_ACCESSOR(Kind, method)                                       \
  def(#method "_", [](Node& n, const char* name, Kind##Attr::ValueType v) { \
    return n.method##_(Symbol::attr(name), std::move(v));                   \
  }).def(#method, [](Node& n, const char* name) {                           \
    return n.method(Symbol::attr(name));                                    \
  })
      .CREATE_ACCESSOR(Float, f)
      .CREATE_ACCESSOR(Floats, fs)
      .CREATE_ACCESSOR(Complex, c)
      .CREATE_ACCESSOR(String, s)
      .CREATE_ACCESSOR(Strings, ss)
      .CREATE_ACCESSOR(Int, i)
      .CREATE_ACCESSOR(Ints, is)
      .CREATE_ACCESSOR(Graph, g)
      .CREATE_ACCESSOR(Graphs, gs)
      .CREATE_ACCESSOR(IValue, ival)
#undef CREATE_ACCESSOR
      // Tensor (t_) -- manually written to unwrap the variable into a tensor.
      .def(
          "t_",
          [](Node& n, const char* name, const torch::autograd::Variable& v) {
            AT_ASSERT(!v.requires_grad());
            return n.t_(Symbol::attr(name), v);
          })
      .def(
          "t",
          [](Node& n, const char* name) { return n.t(Symbol::attr(name)); })
      // Tensors (ts_) -- manually written to unwrap variables into tensors.
      .def(
          "ts_",
          [](Node& n,
             const char* name,
             const std::vector<torch::autograd::Variable>& vs) {
            std::vector<at::Tensor> tensors;
            tensors.reserve(vs.size());
            for (auto& variable : vs) {
              AT_ASSERT(!variable.requires_grad());
              tensors.push_back(variable);
            }
            return n.ts_(Symbol::attr(name), std::move(tensors));
          })
      .def(
          "ts",
          [](Node& n, const char* name) {
            auto tensors = n.ts(Symbol::attr(name));
            std::vector<torch::autograd::Variable> variables;
            variables.reserve(tensors.size());
            for (auto& tensor : tensors) {
              variables.emplace_back(std::move(tensor));
            }
            return variables;
          })
      .def(
          "z_",
          [](Node& n, const char* name, const at::Tensor& v) {
            return n.t_(
                Symbol::attr(name),
                autograd::Variable(v.view(std::vector<int64_t>{}))
                    .set_requires_grad(false));
          })
      .def(
          "z",
          [](Node& n, const char* name) { return n.t(Symbol::attr(name)); })
      .def(
          "ty_",
          [](Node& n, const char* name, const TypePtr& type) {
            return n.ty_(Symbol::attr(name), type);
          })
      .def(
          "ty",
          [](Node& n, const char* name) { return n.ty(Symbol::attr(name)); })
      .def(
          "tys_",
          [](Node& n, const char* name, const std::vector<TypePtr>& types) {
            return n.tys_(Symbol::attr(name), types);
          })
      .def(
          "tys",
          [](Node& n, const char* name) { return n.tys(Symbol::attr(name)); })
      .def(
          "zs_",
          [](Node& n, const char* name, TensorsAttr::ValueType v) {
            for (auto& i : v) {
              i = autograd::Variable(i.view(std::vector<int64_t>{}))
                      .set_requires_grad(false);
            }
            return n.ts_(Symbol::attr(name), std::move(v));
          })
      .def(
          "zs",
          [](Node& n, const char* name) { return n.ts(Symbol::attr(name)); })
      .def(
          "pyobj",
          [](Node& n) {
            return py::handle(n.expect<ConcretePythonOp>()->pyobj.get())
                .cast<py::object>();
          })
      .def("cconv", [](Node& n) { return n.expect<ConcretePythonOp>()->cconv; })
      .def(
          "pyname",
          [](Node& n) { return n.expect<ConcretePythonOp>()->name(); })
      .def("scalar_args", [](Node& n) {
        auto op = n.expect<ConcretePythonOp>();
        auto scalars = py::list();
        auto append = scalars.attr("append");
        for (auto& arg : op->scalar_args) {
          append(py::handle(arg.get()));
        }
        return scalars;
      });

  using ::c10::Type;
  py::class_<Type, TypePtr>(m, "Type")
      .def("__repr__", [](Type& t) { return t.annotation_str(); })
      .def(
          "str",
          [](Type& t) {
            std::ostringstream s;
            s << t;
            return s.str();
          })
      .def(
          "containedTypes",
          [](Type& self) { return self.containedTypes().vec(); })
      .def("kind", [](const Type& t) { return typeKindToString(t.kind()); })
      .def(
          "dim",
          [](Type& t) {
            auto vshape = t.expectRef<TensorType>().sizes();
            return vshape.size() ? py::cast(*vshape.size())
                                 : py::cast<py::none>(Py_None);
          })
      .def(
          "undefined",
          [](Type& t) {
            auto undef = t.expectRef<TensorType>().undefined();
            return undef.has_value() ? py::cast(*undef)
                                     : py::cast<py::none>(Py_None);
          })
      .def(
          "sizes",
          [](Type& t) -> py::object {
            if (auto ptt = t.expect<TensorType>()) {
              if (auto cs = ptt->sizes().concrete_sizes()) {
                return py::cast(*cs);
              }
            }
            return py::none();
          })
      .def(
          "symbolic_sizes",
          [](Type& t) -> py::object {
            if (auto ptt = t.expect<TensorType>()) {
              auto ss = ptt->symbolic_sizes();
              if (!ss.rank().has_value()) {
                return py::none();
              }

              std::vector<int64_t> ss_vals;
              for (size_t i = 0; i < *ss.rank(); ++i) {
                ss_vals.push_back(ss.at(i).value());
              }
              return py::cast(ss_vals);
            }
            return py::none();
          })
      .def(
          "with_sizes",
          [](Type& t, std::optional<std::vector<std::optional<int64_t>>> sizes)
              -> py::object {
            auto ptt = t.expect<TensorType>();
            if (!ptt) {
              return py::none();
            }
            if (!sizes) {
              return py::cast(ptt->withSymbolicShapes(c10::SymbolicShape()));
            }
            return py::cast(ptt->withSymbolicShapes(*sizes));
          })
      .def(
          "varyingSizes",
          [](Type& t) -> py::object {
            if (auto ptt = t.expect<TensorType>()) {
              if (auto s = ptt->sizes().sizes()) {
                return py::cast(s.value());
              }
            }
            return py::none();
          })
      .def(
          "strides",
          [](Type& t) -> py::object {
            if (auto ptt = t.expect<TensorType>()) {
              if (auto cs = ptt->strides().concrete_sizes()) {
                return py::cast(*cs);
              }
            }
            return py::none();
          })
      .def(
          "contiguous",
          [](Type& t) {
            return std::static_pointer_cast<Type>(
                t.expectRef<TensorType>().contiguous());
          })
      .def(
          "scalarType",
          [](Type& t) {
            auto scalar_type = t.expectRef<TensorType>().scalarType();
            return (scalar_type) ? toString(*scalar_type) : nullptr;
          })
      .def(
          "device",
          [](Type& t) -> py::object {
            auto device = t.expectRef<TensorType>().device();
            if (!device) {
              return py::none();
            }
            PyObject* thp_device = THPDevice_New(device.value());
            return py::reinterpret_borrow<py::object>(thp_device);
            // return toPyObject(device.value());
          })
      .def(
          "with_device",
          [](Type& t, py::object device) -> py::object {
            at::Device c_device =
                python::detail::py_object_to_device(std::move(device));
            if (auto ptt = t.expect<TensorType>()) {
              return py::cast(ptt->withDevice(c_device));
            }
            return py::none();
          })
      .def(
          "dtype",
          [](Type& t) -> py::object {
            auto scalar_type = t.expectRef<TensorType>().scalarType();
            if (!scalar_type) {
              return py::none();
            }
            THPDtype* thp_dtype = torch::getTHPDtype(*scalar_type);
            py::object dtype =
                py::reinterpret_borrow<py::object>((PyObject*)thp_dtype);
            return dtype;
          })
      .def(
          "with_dtype",
          [](Type& t, py::object dtype) -> py::object {
            at::ScalarType scalar_type =
                python::detail::py_object_to_dtype(std::move(dtype));

            if (auto ptt = t.expect<TensorType>()) {
              // auto scalar_type = dtype->scalar_type;
              return py::cast(ptt->withScalarType(scalar_type));
            }
            return py::none();
          })
      .def(
          "__eq__",
          [](const TypePtr& self, const TypePtr& other) {
            if (!other) {
              return false;
            }
            return *self == *other;
          })
      .def(
          "isSubtypeOf",
          [](const TypePtr& self, const TypePtr& other) {
            if (!other) {
              return false;
            }
            return self->isSubtypeOf(other);
          })
      .def(
          "is_interface_type",
          [](const TypePtr& self) {
            return self->castRaw<InterfaceType>() != nullptr;
          })
      .def(
          "requires_grad",
          [](const TypePtr& self) -> bool { return self->requires_grad(); })
      .def_property_readonly(
          "annotation_str", [](const std::shared_ptr<Type>& self) {
            return self->annotation_str();
          });

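  // Illustrative Python usage of the TensorType accessors bound above
  // (value is assumed to be a torch._C.Value whose type() is a TensorType):
  //   t = value.type()
  //   t.kind(), t.sizes(), t.dtype(), t.device()
  //   t2 = t.with_dtype(torch.float32)
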
  py::class_<AnyType, Type, AnyTypePtr>(m, "AnyType")
      .def_static("get", &AnyType::get);
  py::class_<NumberType, Type, NumberTypePtr>(m, "NumberType")
      .def_static("get", &NumberType::get);
  py::class_<IntType, Type, IntTypePtr>(m, "IntType")
      .def_static("get", &IntType::get);
  py::class_<SymIntType, Type, SymIntTypePtr>(m, "SymIntType")
      .def_static("get", &SymIntType::get);
  py::class_<SymBoolType, Type, SymBoolTypePtr>(m, "SymBoolType")
      .def_static("get", &SymBoolType::get);
  py::class_<FloatType, Type, FloatTypePtr>(m, "FloatType")
      .def_static("get", &FloatType::get);
  py::class_<ComplexType, Type, ComplexTypePtr>(m, "ComplexType")
      .def_static("get", &ComplexType::get);
  py::class_<TensorType, Type, TensorTypePtr>(m, "TensorType")
      .def_static("get", &TensorType::get)
      .def_static("getInferred", &TensorType::getInferred)
      .def_static("create_from_tensor", [](const at::Tensor& t) {
        return TensorType::create(t);
      });
  py::class_<BoolType, Type, BoolTypePtr>(m, "BoolType")
      .def_static("get", &BoolType::get);
  py::class_<StringType, Type, StringTypePtr>(m, "StringType")
      .def_static("get", &StringType::get);
  py::class_<DeviceObjType, Type, DeviceObjTypePtr>(m, "DeviceObjType")
      .def_static("get", &DeviceObjType::get);
  // TODO(antoniojkim): Add GeneratorType to the public API once it's been
  //                    added to the public documentation
  py::class_<GeneratorType, Type, GeneratorTypePtr>(m, "_GeneratorType")
      .def_static("get", &GeneratorType::get);
  py::class_<StreamObjType, Type, StreamObjTypePtr>(m, "StreamObjType")
      .def_static("get", &StreamObjType::get);
  py::class_<PyObjectType, Type, PyObjectTypePtr>(m, "PyObjectType")
      .def_static("get", &PyObjectType::get);
  py::class_<NoneType, Type, NoneTypePtr>(m, "NoneType")
      .def_static("get", &NoneType::get);

  py::class_<TupleType, Type, TupleTypePtr>(m, "TupleType")
      .def(py::init([](std::vector<TypePtr> a) {
        return TupleType::create(std::move(a));
      }))
      .def("elements", [](TupleType& self) {
        std::vector<TypePtr> types;
        for (const auto& type : self.elements()) {
          types.push_back(type);
        }
        return types;
      });
  py::class_<UnionType, Type, UnionTypePtr>(m, "UnionType")
      .def(py::init(
          [](const std::vector<TypePtr>& a) { return UnionType::create(a); }));
  py::class_<ListType, Type, ListTypePtr>(m, "ListType")
      .def(py::init([](const TypePtr& a) { return ListType::create(a); }))
      .def_static("ofInts", &ListType::ofInts)
      .def_static("ofTensors", &ListType::ofTensors)
      .def_static("ofFloats", &ListType::ofFloats)
      .def_static("ofComplexDoubles", &ListType::ofComplexDoubles)
      .def_static("ofBools", &ListType::ofBools)
      .def_static("ofStrings", &ListType::ofStrings)
      .def("getElementType", &ListType::getElementType);
  py::class_<DictType, Type, DictTypePtr>(m, "DictType")
      .def(py::init([](TypePtr key, TypePtr value) {
        return DictType::create(std::move(key), std::move(value));
      }))
      .def("getKeyType", &DictType::getKeyType)
      .def("getValueType", &DictType::getValueType);
  py::class_<OptionalType, Type, OptionalTypePtr>(m, "OptionalType")
      .def(py::init([](const TypePtr& a) { return OptionalType::create(a); }))
      .def_static("ofTensor", &OptionalType::ofTensor)
      .def("getElementType", &OptionalType::getElementType);
  py::class_<RRefType, Type, RRefTypePtr>(m, "RRefType")
      .def(py::init([](TypePtr a) { return RRefType::create(std::move(a)); }))
      .def("getElementType", &RRefType::getElementType);

  py::class_<FutureType, Type, FutureTypePtr>(m, "FutureType")
      .def(py::init([](TypePtr a) { return FutureType::create(std::move(a)); }))
      .def("getElementType", &FutureType::getElementType);

  py::class_<AwaitType, Type, AwaitTypePtr>(m, "AwaitType")
      .def(py::init([](TypePtr a) { return AwaitType::create(std::move(a)); }))
      .def("getElementType", &AwaitType::getElementType);

  py::class_<ClassType, Type, ClassTypePtr>(m, "ClassType")
      .def(py::init([](const std::string& qualified_name) {
        return get_python_cu()->get_class(c10::QualifiedName(qualified_name));
      }))
      .def("name", [](ClassType& self) { return self.name()->name(); })
      .def("qualified_name", [](ClassType& self) {
        return self.name()->qualifiedName();
      });
  py::class_<EnumType, Type, EnumTypePtr>(m, "EnumType")
      .def(py::init([](const std::string& qualified_name,
                       TypePtr value_type,
                       const std::vector<py::object>& enum_names_values) {
        std::vector<std::pair<std::string, IValue>> names_values;
        names_values.reserve(enum_names_values.size());
        for (const auto& enum_name_value : enum_names_values) {
          auto enum_name = py::cast<std::string>(enum_name_value.attr("name"));
          auto enum_value = toIValue(enum_name_value.attr("value"), value_type);
          names_values.emplace_back(enum_name, enum_value);
        }
        return EnumType::create(
            c10::QualifiedName(qualified_name),
            std::move(value_type),
            std::move(names_values),
            get_python_cu());
      }));
  py::class_<InterfaceType, Type, InterfaceTypePtr>(m, "InterfaceType")
      .def(py::init([](const std::string& qualified_name) {
        return get_python_cu()->get_interface(
            c10::QualifiedName(qualified_name));
      }))
      .def(
          "getMethod",
          [](InterfaceType& self, const std::string& name) {
            return self.getMethod(name);
          },
          py::return_value_policy::reference)
      .def("getMethodNames", [](InterfaceType& self) {
        std::vector<std::string> names;
        for (const FunctionSchema& fn : self.methods()) {
          names.emplace_back(fn.name());
        }
        return names;
      });
  using ::c10::InferredType;
  py::class_<InferredType, std::shared_ptr<InferredType>>(m, "InferredType")
      .def(py::init([](std::shared_ptr<Type> type) {
        return std::make_shared<InferredType>(std::move(type));
      }))
      .def(py::init([](std::string reason) {
        return std::make_shared<InferredType>(std::move(reason));
      }))
      .def(
          "type",
          [](const std::shared_ptr<InferredType>& self) {
            return self->type();
          })
      .def(
          "success",
          [](const std::shared_ptr<InferredType>& self) {
            return self->success();
          })
      .def("reason", [](const std::shared_ptr<InferredType>& self) {
        return self->reason();
      });

  py::class_<Use>(m, "Use")
      .def_readonly("user", &Use::user)
      .def_readonly("offset", &Use::offset)
      .def("isAfter", [](Use& self, Use& other_use) {
        return isBeforeOrAfter(self, other_use, false);
      });

  py::class_<torch::jit::ShapeComputeGraphMapping>(
      m, "_ShapeComputeGraphMapping")
      .def(
          "partial_eval_shape_graph",
          [](ShapeComputeGraphMapping& g) {
            return g.partial_eval_shape_graph;
          })
      .def(
          "graph_output_to_symbolic_shape_dim",
          [](ShapeComputeGraphMapping& g) {
            return g.graph_output_to_symbolic_shape_dim_;
          });
}
} // namespace torch::jit