xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/ir/ir.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/ir/ir.h>
2 
3 #include <ATen/core/builtin_function.h>
4 #include <ATen/core/function.h>
5 #include <c10/util/Exception.h>
6 #include <c10/util/StringUtil.h>
7 #include <c10/util/irange.h>
8 #include <torch/csrc/jit/api/function_impl.h>
9 #include <torch/csrc/jit/frontend/error_report.h>
10 #include <torch/csrc/jit/frontend/schema_matching.h>
11 #include <torch/csrc/jit/ir/constants.h>
12 #include <torch/csrc/jit/runtime/operator.h>
13 #include <torch/csrc/jit/serialization/python_print.h>
14 
15 #include <algorithm>
16 #include <iostream>
17 #include <locale>
18 #include <memory>
19 #include <set>
20 #include <sstream>
21 #include <string>
22 #include <unordered_map>
23 #include <unordered_set>
24 #include <utility>
25 
26 namespace torch::jit {
27 
namespace utils {
// Returns the dotted module-instance hierarchy for the node's inlined call
// stack (e.g. "top(ModA).sub(ModB)"), or "" when no callstack was recorded.
std::string getNodesModuleHierarchy(const Node& n) {
  if (!n.callstack().has_value()) {
    return std::string();
  }
  InlinedCallStackPtr callstack_ptr = n.callstack().value();
  std::string module_hierarchy;
  for (auto& entry : callstack_ptr->vec()) {
    const auto& opt_module_info = std::get<kModuleInstanceInfo>(entry);
    if (opt_module_info.has_value()) {
      const auto& module_instance_info = opt_module_info.value();
      if (!module_hierarchy.empty()) {
        module_hierarchy.append(".");
      }
      module_hierarchy.append(utils::get_module_info(module_instance_info));
    } else {
      // NOTE(review): unlike the branch above, this unconditionally prepends
      // a '.', so a leading entry with no module info produces a string that
      // starts with ".UNKNOWN_...". Presumably cosmetic only — confirm before
      // changing the format.
      module_hierarchy += ".UNKNOWN_INSTANCE(UNKNOWN_TYPE)";
    }
  }
  return module_hierarchy;
}
} // namespace utils
50 
namespace {

// Constants relating to maintaining the topological index of nodes.
//
// Lower and upper bounds of the index. Inclusive range.
constexpr topo_position_t kLowerBound = INT64_MIN;
constexpr topo_position_t kUpperBound = INT64_MAX;
constexpr topo_position_t kMidPoint = 0;

// How far away to space nodes that are appended to the graph.
// should be 2^n, where:
//   - n is the maximum number of repeated insertions without a re-index
//   - 2^(64-n) is the maximum number of appends to the end without reindex
constexpr topo_position_t kAppendInterval = 1099511627776ULL /* 2^40 */;

// Prints a value reference in IR text form, e.g. "%x.1".
void printValueRef(std::ostream& out, const Value* n) {
  out << "%" << n->debugName();
}

// True when str consists solely of ASCII digits.
// Note: an empty string vacuously counts as a "number" here.
bool isNumber(c10::string_view str) {
  return str.find_first_not_of("0123456789") == std::string::npos;
}

// Purely-numeric attribute names are prefixed with '_' so they remain
// valid identifiers when printed.
std::string normalizeAttrName(c10::string_view field) {
  if (isNumber(field)) {
    return "_" + std::string{field};
  }
  return std::string{field};
}

// Appends to `ret` every node of the given kind in `block`, optionally
// descending into sub-blocks when `recurse` is set.
void findAllNodes(
    Block& block,
    Symbol kind,
    bool recurse,
    std::vector<Node*>& ret) {
  for (Node* n : block.nodes()) {
    if (n->kind() == kind) {
      ret.push_back(n);
    }
    if (recurse) {
      for (auto b : n->blocks()) {
        findAllNodes(*b, kind, recurse, ret);
      }
    }
  }
}

} // namespace
99 
// NB: This overload will become ambiguous with the one Caffe2 provides in its
// logging, if they ever intersect.
// Streams a vector by delegating to the ArrayRef printer below.
template <typename T>
std::ostream& operator<<(std::ostream& out, const std::vector<T>& nodes) {
  out << at::ArrayRef<T>{nodes};
  return out;
}
107 
108 template <typename T>
printValueRefs(std::ostream & out,const at::ArrayRef<T> nodes)109 static std::ostream& printValueRefs(
110     std::ostream& out,
111     const at::ArrayRef<T> nodes) {
112   size_t i = 0;
113   for (auto n : nodes) {
114     if (i++ > 0) {
115       out << ", ";
116     }
117     printValueRef(out, n);
118   }
119   return out;
120 }
121 
122 // Can't make these two overloads directly a template, it'll be ambiguous with
123 // the global printer for operator<<.
124 
// Can't make these two overloads directly a template, it'll be ambiguous with
// the global printer for operator<<.
static std::ostream& operator<<(
    std::ostream& out,
    const at::ArrayRef<const Value*> nodes) {
  return printValueRefs(out, nodes);
}
130 
// Helper that, when streamed, prints a value list with optional type
// annotations (": <type>") using a configurable delimiter.
struct const_value_list_with_types {
  const ArrayRef<const Value*> values;
  std::string delim;
  const_value_list_with_types(
      ArrayRef<const Value*> values,
      std::string delim_ = ", ")
      : values(values), delim(std::move(delim_)) {}
};
139 
operator <<(std::ostream & out,const const_value_list_with_types & l)140 static std::ostream& operator<<(
141     std::ostream& out,
142     const const_value_list_with_types& l) {
143   size_t i = 0;
144   for (auto n : l.values) {
145     if (i++ > 0) {
146       out << l.delim;
147     }
148     printValueRef(out, n);
149     if (c10::type_verbosity() >= c10::TypeVerbosity::Type) {
150       out << " : ";
151       out << *n->type();
152     }
153   }
154   return out;
155 }
156 
// Prints a tensor attribute: scalars as "{v}", small tensors inline (newlines
// stripped), anything larger as the placeholder "<Tensor>".
static void printAttribute(std::ostream& out, const at::Tensor& tensor) {
  // 1-elem tensors are usually boxed scalars, so print them like it
  if (tensor.numel() == 1) {
    auto scalar_tensor = tensor.view(std::vector<int64_t>{}).item();
    out << "{";
    if (scalar_tensor.isFloatingPoint()) {
      out << scalar_tensor.toDouble();
    } else if (scalar_tensor.isComplex()) {
      out << scalar_tensor.toComplexDouble();
    } else {
      out << scalar_tensor.toLong();
    }
    out << "}";
  } else if (tensor.numel() <= max_tensor_display_size) {
    // TODO: This is awful code.  Also it doesn't work on Windows.
    std::ostringstream tensor_ss;
    tensor_ss << tensor;
    std::string tensor_s{tensor_ss.str()};
    // Remove newlines so the tensor stays on one line of IR output
    std::replace(tensor_s.begin(), tensor_s.end(), '\n', ' ');
    out << tensor_s;
  } else {
    out << "<Tensor>";
  }
}
182 
// Prints an IValue attribute via IValue::repr, overriding the formatting of
// tensors, tensor lists, and non-module objects with compact placeholders.
static void printAttribute(std::ostream& out, const IValue& ival) {
  // Returns true when it handled the value itself; false defers to repr().
  const auto customFormatter = [](std::ostream& ss, const IValue& input) {
    if (input.isTensor()) {
      printAttribute(ss, input.toTensor());
      return true;
    } else if (input.isTensorList()) {
      ss << "[<Tensors>]";
      return true;
    } else if (input.isObject() && !input.type()->is_module()) {
      ss << "object(" << &input.toObjectRef() << ")";
      return true;
    }
    return false;
  };
  ival.repr(out, customFormatter);
}
199 
printTypeList(std::ostream & out,const std::vector<TypePtr> & items)200 static void printTypeList(
201     std::ostream& out,
202     const std::vector<TypePtr>& items) {
203   out << "[";
204   int i = 0;
205   for (auto& item : items) {
206     if (i++ > 0)
207       out << ", ";
208     out << *item;
209   }
210   out << "]";
211 }
212 
// Dispatches on the attribute's kind and prints its value; graphs and tensor
// lists are abbreviated with placeholders rather than printed in full.
void Node::printAttrValue(std::ostream& out, const Symbol& name) const {
  switch (kindOf(name)) {
    case AttributeKind::c:
      printAttribute(out, c(name));
      break;
    case AttributeKind::cs:
      // TODO(@anjali411): fix this
      // Complex lists are not printable yet; hitting this is a hard error.
      AT_ASSERT(false);
      break;
    case AttributeKind::f:
      printAttribute(out, f(name));
      break;
    case AttributeKind::fs:
      printAttribute(out, fs(name));
      break;
    case AttributeKind::i:
      printAttribute(out, i(name));
      break;
    case AttributeKind::is:
      printAttribute(out, is(name));
      break;
    case AttributeKind::s:
      printAttribute(out, s(name));
      break;
    case AttributeKind::ss:
      printAttribute(out, ss(name));
      break;
    case AttributeKind::t:
      printAttribute(out, t(name));
      break;
    case AttributeKind::ts:
      out << "[<Tensors>]";
      break;
    case AttributeKind::ival:
      printAttribute(out, ival(name));
      break;
    case AttributeKind::g:
      out << "<Graph>";
      break;
    case AttributeKind::gs:
      out << "[<Graphs>]";
      break;
    case AttributeKind::ty:
      out << *ty(name);
      break;
    case AttributeKind::tys:
      printTypeList(out, tys(name));
      break;
  }
}
263 
printAttributes(std::ostream & out,bool ignore_subgraph=false) const264 void Node::printAttributes(std::ostream& out, bool ignore_subgraph = false)
265     const {
266   out << "[";
267   auto names = attributeNames();
268   int i = 0;
269   for (auto name : names) {
270     if (ignore_subgraph && name == attr::Subgraph) {
271       continue;
272     }
273     if (i++ > 0) {
274       out << ", ";
275     }
276     // TODO: debugging mode to see the qualifier.  We definitely
277     // don't want to print the qualifier since it should always
278     // be attribute, but you might be able to track down a weird
279     // bug by printing it out.
280     out << name.toUnqualString() << "=";
281 
282     printAttrValue(out, name);
283   }
284   out << "]";
285 }
286 
sourceRange() const287 SourceRange Node::sourceRange() const {
288   if (source_range_) {
289     return *source_range_;
290   }
291   return SourceRange();
292 }
293 
// Writes `level` levels of two-space indentation to `out`; returns `out` so
// the call can be chained into a larger `<<` expression.
static std::ostream& indent(std::ostream& out, size_t level) {
  // Single bulk write via the std::string fill constructor replaces the old
  // loop, which needed a "(void)i" workaround to silence the unused-variable
  // warning on its induction variable.
  out << std::string(level * 2, ' ');
  return out;
}
301 
// Prints one node in IR text form:
//   %out1 : Ty, %out2 : Ty = kind[attrs](%in1, %in2), scope: ... # file:line:col
// Subgraph-bearing nodes are printed as "kind_N" and collected into `groups`
// so the caller can print their subgraphs afterwards; when `print_body` is
// set, nested blocks are printed indented beneath the node.
std::ostream& Node::print(
    std::ostream& out,
    size_t level,
    std::vector<const Node*>* groups,
    bool print_source_locations,
    bool print_attributes,
    bool print_scopes,
    bool print_body) const {
  auto outs = outputs();
  indent(out, level) << const_value_list_with_types(outs);
  out << " = ";
  if (kind() == prim::PythonOp) {
    // PythonOps print as "^name" followed by attributes and scalar args.
    auto* pyOp = static_cast<const ::torch::jit::PythonOp*>(this);
    out << "^" << pyOp->name();
    printAttributes(out, /*ignore_subgraph=*/false);
    pyOp->writeScalars(out);
  } else if (hasAttribute(attr::Subgraph) && groups) {
    // Subgraph nodes get a numeric suffix and are deferred to `groups`;
    // the Subgraph attribute itself is skipped here.
    out << kind().toQualString() << "_" << groups->size();
    if (print_attributes && numAttributes() > 1 &&
        kind() != prim::DifferentiableGraph) {
      printAttributes(out, /*ignore_subgraph=*/true);
    }

    groups->push_back(this);
  } else {
    out << kind().toQualString();
    if (print_attributes && hasAttributes()) {
      printAttributes(out);
    }
  }
  out << "(" << inputs() << ")";

  if (print_scopes) {
    std::string scName = scopeName();
    if (!scName.empty()) {
      out << ", ";
      out << "scope: " << scName;
    }
  }

  // In debug print, append file:line:col as a comment after each node
  if (print_source_locations) {
    SourceRange r = sourceRange();
    if (sourceRange().source()) {
      // Prefer the original (pre-transformation) source range when available.
      if (auto orig = sourceRange().source()->findSourceRangeThatGenerated(r)) {
        r = *orig;
      }
    }
    if (auto file_line_col = r.file_line_col()) {
      auto [filename, line, col] = *file_line_col;
      out << " # " << filename << ":" << line << ":" << col;
    }
  }

  if (!print_body) {
    return out;
  }

  out << "\n";

  // Print each nested block as "blockN(inputs):" with its nodes indented.
  for (const auto i : c10::irange(blocks().size())) {
    auto b = blocks()[i];
    indent(out, level + 1) << "block" << i << "("
                           << const_value_list_with_types(b->inputs())
                           << "):\n";
    for (auto nested : b->nodes()) {
      nested->print(out, level + 2, groups);
    }
    indent(out, level + 2) << "-> (" << b->outputs() << ")\n";
  }

  return out;
}
375 
// Streams a node without source locations or subgraph collection.
std::ostream& operator<<(std::ostream& out, const Node& n) {
  return n.print(out, 0, nullptr);
}
379 
print(std::ostream & out,bool print_source_locations) const380 std::ostream& Graph::print(std::ostream& out, bool print_source_locations)
381     const {
382   out << "graph(" << const_value_list_with_types(inputs(), ",\n      ")
383       << "):\n";
384   std::vector<const Node*> groups;
385   for (auto n : nodes()) {
386     n->print(out, 1, &groups, print_source_locations);
387   }
388   out << "  return (" << outputs() << ")\n";
389   size_t i = 0;
390   for (auto fg : groups) {
391     out << "with " << fg->kind().toQualString() << "_" << i++ << " = "
392         << *fg->g(attr::Subgraph);
393   }
394   out.flush();
395 
396   /*
397   // Uncomment this to debug all_nodes issues
398   {
399     out << "\n";
400     out << "all_nodes:\n";
401     for (auto& n : all_nodes) {
402       printNode(out, const_cast<Node*>(n), nullptr);
403     }
404   }
405   */
406   return out;
407 }
408 
// Streams a graph with source locations included.
std::ostream& operator<<(std::ostream& out, const Graph& g) {
  return g.print(out, true);
}
412 
// Lint helper: asserts that every tensor-typed input/output of `node` that
// carries device information agrees on the same device.
static void checkSameDevice(const Node* node) {
  bool has_device = false;
  std::optional<at::Device> device = std::nullopt;
  auto checkValue = [&](const Value* v) {
    if (TensorTypePtr type = v->type()->cast<TensorType>()) {
      if (type->device() && !has_device) {
        // First device seen becomes the reference.
        has_device = true;
        device = *type->device();
      } else {
        // NOTE(review): once a device has been seen, a later tensor type with
        // *no* device info also trips this assert (nullopt != device) —
        // presumably intentional strictness for fusion groups; confirm before
        // relaxing.
        AT_ASSERT(device == type->device());
      }
    }
  };
  for (auto input : node->inputs()) {
    checkValue(input);
  }
  for (auto output : node->outputs()) {
    checkValue(output);
  }
}
433 
// Ordered set of nodes; ordering is required by the std::includes checks in
// the lint code below.
using node_set = std::set<const Node*>;
// Expands to the begin/end iterator pair of a container, for <algorithm> calls.
#define ALL_OF(container) container.begin(), container.end()

// These functions purposely operate on the internal members directly, to force
// you to think about how the invariants change if you change the data
// representation (even if the external API does not change.)

// NB: This assert is written to assume you don't have any unattached
// nodes.  Unattached nodes can occur while manipulations to the
// graph are occurring.
// Checks per-node invariants: input/use consistency, graph membership, and
// kind-specific rules (e.g. constants take no inputs, fusion groups lint
// their subgraphs). Asserts on violation.
void Node::lint() const {
  // Node invariants
  // - if node should live in list, nodes_iter is consistent
  // - Inputs are all marked as a use by the nodes they refer to
  // - Owning graph is non-null and consistent
  // - The "Select" invariant, when the node is MultiReturn
  //
  // The handle invariant:
  //    If a node takes a handle as an input, it is always the
  //    LAST input of the node.  There is at most one handle input.

  {
    size_t i = 0;
    for (auto input : inputs_) {
      // WARNING: O(n^2)
      // Every input must record a matching Use pointing back at this node
      // at the right offset.
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      AT_ASSERT(
          std::find(ALL_OF(input->uses_), Use(const_cast<Node*>(this), i)) !=
          input->uses_.end());
      AT_ASSERT(graph_->all_nodes.count(this) == 1);
      i++;
    }
  }

  for (auto o : outputs()) {
    for (auto use : o->uses()) {
      // Use invariants
      // - Use is consistent with inputs
      // - Every user node is live (checked in Graph)
      AT_ASSERT(use.user->inputs_[use.offset] == o);
    }
  }

  // Node subclass invariants
  switch (kind()) {
    case prim::Constant:
      AT_ASSERT(inputs_.empty());
      break;
    case prim::Return:
      // Return uses is zero
      AT_ASSERT(outputs().empty());
      break;
    case prim::Param:
      // Param inputs is zero
      AT_ASSERT(inputs_.empty());
      break;
    case prim::PythonOp: {
      // Python operator cconv is correct
      auto* value = static_cast<const PythonOp*>(this);
      value->lint_python();
      break;
    }
    case prim::Eval:
      // TODO: add invariants
      // TODO: It's not good for these ops to be top-level, it makes cases
      // longer.
      break;
    case prim::FusionGroup:
    case prim::CudaFusionGroup:
    case prim::oneDNNFusionGroup:
      // All tensor values in/out of a fusion group must agree on device,
      // and the fused subgraph must itself be well-formed.
      checkSameDevice(this);
      // TODO: Typecheck the parameters
      g(attr::Subgraph)->lint();
      break;
  }
}
510 
511 // TODO: When lint fails, give better indication about which
512 // instruction triggered the failure.
// TODO: When lint fails, give better indication about which
// instruction triggered the failure.
// Validates whole-graph invariants: topological order, unique value ids,
// scoping of uses, and consistency between all_nodes and the block structure.
void Graph::lint() const {
  // Graph invariants

  // Uncomment the following to see the graph
  // std::cout << *const_cast<Graph*>(this);

  // nodes
  // - nodes_ is a valid topological ordering for inputs
  // - No repeated nodes
  // - Params and return do NOT occur in nodes
  // - next_unique_ is greater than all uniques in graph
  // - uniques in all_nodes are unique
  // - every use will occur later in the toposort

  // Chain of lexical scopes, innermost first; a value/node is "in scope"
  // if it appears in this scope or any ancestor.
  struct LintScope {
    LintScope() = default;
    LintScope(std::unique_ptr<LintScope> parent) : parent(std::move(parent)) {}
    bool contains(const Value* v) {
      return values.count(v) > 0 || (parent && parent->contains(v));
    }
    bool contains(const Node* n) {
      return nodes.count(n) > 0 || (parent && parent->contains(n));
    }
    void insert(const Value* v) {
      AT_ASSERT(!contains(v));
      values.insert(v);
    }
    void insert(const Node* n) {
      AT_ASSERT(!contains(n));
      nodes.insert(n);
    }
    // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
    std::unique_ptr<LintScope> parent;

   private:
    std::unordered_set<const Value*> values;
    std::unordered_set<const Node*> nodes;
  };
  // Struct enables mutual recursion in linting methods.
  // Putting it inside Graph::lint enables access to private Graph members
  struct LintImpl {
    LintImpl(const Graph& g)
        : g(g),
          scope(new LintScope()),
          all_nodes_set(ALL_OF(g.all_nodes)) {} // NB: all_nodes is *unordered*
    const Graph& g;
    std::unique_ptr<LintScope> scope;
    std::unordered_set<size_t> seen_uniques;
    std::unordered_map<const Node*, int64_t> anticipated_uses;
    node_set all_nodes_set;
    node_set sum_set; // union of every node seen across all blocks

    // Registers a freshly-defined value: unique id not seen before and below
    // next_unique_; records each of its users as an anticipated use.
    void check_value(const Value* v) {
      scope->insert(v);
      auto b2 = seen_uniques.insert(v->unique());
      AT_ASSERT(b2.second); // insertion took place
      AT_ASSERT(v->unique() < g.next_unique_);

      for (auto use : v->uses()) {
        // Users must come later in the toposort, so they can't be in scope yet.
        AT_ASSERT(!scope->contains(use.user));
        AT_ASSERT(g.all_nodes.count(use.user) == 1);
        anticipated_uses[use.user]++; // int default constructs to 0
      }
    }
    void check_node(const Node* n) {
      for (auto input : n->inputs_) {
        if (!scope->contains(input)) {
          AT_ASSERTM(0, input->unique(), " not in scope");
        }
      }
      // Every input of n must have been anticipated exactly once.
      AT_ASSERT(anticipated_uses[n] == static_cast<int64_t>(n->inputs_.size()));
      anticipated_uses[n] = -1; // we saw the anticipated user!
      scope->insert(n);
      for (auto block : n->blocks()) {
        // Push a child scope for the sub-block, pop it afterwards.
        // (Assigning from scope->parent is safe: unique_ptr move-assignment
        // releases the new pointer before destroying the old object.)
        scope = std::make_unique<LintScope>(std::move(scope));
        check_block(block);
        scope = std::move(scope->parent);
      }
      size_t i = 0;
      for (auto o : n->outputs()) {
        AT_ASSERT(o->node() == n);
        AT_ASSERT(i++ == o->offset_);
        check_value(o);
      }
      n->lint();
    }
    void check_block(const Block* b) {
      // Check topological ordering
      AT_ASSERT(b->param_node()->isBefore(*b->nodes().begin()));
      auto curNode = *b->nodes().begin();
      while (curNode != b->return_node()) {
        AT_ASSERT(curNode->isBefore(curNode->next()));
        curNode = curNode->next();
      }

      for (auto input : b->inputs()) {
        check_value(input);
        AT_ASSERT(input->node()->kind_ == prim::Param);
      }

      for (auto n : b->nodes()) {
        AT_ASSERT(n->kind_ != prim::Param);
        AT_ASSERT(n->kind_ != prim::Return);
        check_node(n);
      }

      AT_ASSERT(b->output_->kind() == prim::Return);
      check_node(b->output_);

      // all_nodes
      // - inputs_, output_ and nodes_ are all included in all_nodes
      // - all_nodes does not contain dead nodes??? (likely to be temporarily
      // suspended).  Weaker: all_nodes contains all inputs and returns
      // - only one return node???

      node_set nodes_set(ALL_OF(b->nodes()));
      node_set inputs_set{b->input_};
      node_set output_set{b->output_};
      // TODO: Make a more type safe std::includes wrapper which disallows use
      // on non-ordered containers
      AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(nodes_set)));
      AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(inputs_set)));
      AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(output_set)));

      sum_set.insert(ALL_OF(nodes_set));
      sum_set.insert(ALL_OF(inputs_set));
      sum_set.insert(ALL_OF(output_set));
    }
    void check_graph() {
      node_set all_nodes_set(
          ALL_OF(g.all_nodes)); // NB: all_nodes is *unordered*

      check_block(g.block_);
      // Every anticipated user must have been visited (-1 sentinel).
      for (auto kv : anticipated_uses) {
        AT_ASSERT(kv.second == -1);
      }
      // No node in all_nodes may be missing from the blocks we walked.
      AT_ASSERT(std::includes(ALL_OF(sum_set), ALL_OF(all_nodes_set)));
    }
  };
  LintImpl(*this).check_graph();
}
654 
// Debug convenience: prints the graph to stdout.
void Graph::dump() const {
  std::cout << *this << "\n";
}
658 
// Enters a named tracing scope: pushes onto current_scope_, inserts a
// prim::TracedModuleForward node tagged with the scope name, and moves the
// insertion point inside its new block.
void Graph::push_scope(const std::string& scope_name) {
  current_scope_ = current_scope_->push(Symbol::scope(scope_name));
  Node* block_node = insertNode(create(prim::TracedModuleForward, 0));
  block_node->s_(attr::scope, scope_name);
  Block* b = block_node->addBlock();
  setInsertPoint(b);
}
// Leaves the current tracing scope; if the insertion point is inside a
// TracedModuleForward block, moves it to just after that node.
void Graph::pop_scope() {
  current_scope_ = current_scope_->parent();
  // NOTE(review): owningNode() is dereferenced unconditionally — presumably
  // pop_scope is only called when the insert point sits inside a scope node
  // (never the top-level block, whose owningNode is null); confirm callers.
  if (insertPoint()->owningBlock()->owningNode()->kind() ==
      prim::TracedModuleForward) {
    setInsertPoint(insertPoint()->owningBlock()->owningNode()->next());
  }
}
673 
// Free-function wrapper over Graph::lint for use as a pass.
void LintGraph(const std::shared_ptr<Graph>& graph) {
  graph->lint();
}
677 
// Constructs a block owned by `graph_` (optionally attached to `node_`),
// creating its Param (input) and Return (output) sentinel nodes and linking
// them into a circular doubly-linked list that frames the node sequence.
Block::Block(Graph* graph_, Node* node_)
    : graph_(graph_),
      output_(graph_->create(prim::Return, 0)),
      input_(graph_->create(prim::Param, 0)),
      owning_node_(node_) {
  input_->next() = output_;
  input_->prev() = output_;
  output_->next() = input_;
  output_->prev() = input_;

  graph_->all_blocks.emplace(this);
  output_->owning_block_ = this;
  // Sentinels sit at the extreme topological positions so every real node
  // falls strictly between them.
  output_->topo_position_ = kUpperBound;
  input_->owning_block_ = this;
  input_->topo_position_ = kLowerBound;
}
694 
reIndexTopology()695 void Block::reIndexTopology() {
696   auto curPos = kLowerBound;
697   for (auto node : nodes()) {
698     AT_ASSERT(curPos <= (kUpperBound - kAppendInterval));
699     curPos += kAppendInterval;
700     node->topo_position_ = curPos;
701   }
702 }
703 
// Clones the contents of `src` into this block. Values defined inside `src`
// are remapped through a local table; free values (defined outside `src`)
// are resolved via the caller-supplied `value_map`.
void Block::cloneFrom(Block* src, std::function<Value*(Value*)> value_map) {
  std::unordered_map<Value*, Value*> local_map;
  // Prefer the local clone of a value; fall back to the caller's mapping for
  // values not produced inside src.
  auto env = [&](Value* v) {
    auto it = local_map.find(v);
    if (it != local_map.end()) {
      return it->second;
    }
    return value_map(v);
  };

  auto graph = owningGraph();
  for (auto input : src->inputs()) {
    local_map[input] = this->addInput()->copyMetadata(input);
  }

  for (auto node : src->nodes()) {
    auto new_node = this->appendNode(graph->createClone(node, env));
    // Map each original output to its clone and copy name/type metadata.
    for (size_t i = 0; i < node->outputs().size(); ++i) {
      auto oo = node->outputs()[i];
      auto no = new_node->outputs()[i];
      local_map[oo] = no;
      no->copyMetadata(oo);
    }
  }
  for (auto output : src->outputs()) {
    this->registerOutput(env(output));
  }
}
732 
// Destroys the block: all nodes in reverse order (so uses die before defs),
// then the sentinel nodes, then unregisters the block from the graph.
void Block::destroy() {
  // we cannot destroy the output because it is used as the sentinel
  // for the nodes() list and has to remain valid for the loop
  output_->removeAllInputs();
  for (auto it = this->nodes().reverse().begin(),
            end = this->nodes().reverse().end();
       it != end;
       ++it) {
    it.destroyCurrent();
  }
  output_->destroy();
  input_->destroy();
  graph_->freeBlock(this);
}
747 
cloneFrom(Graph & src)748 void Graph::cloneFrom(Graph& src) {
749   auto env = [](Value* v) -> Value* {
750     AT_ERROR(
751         "Graph::copy() encountered a use of a value " + v->debugName() +
752         " not in scope. Run lint!");
753   };
754   block()->cloneFrom(src.block(), env);
755 }
756 
// Returns a deep copy of this graph with shared ownership.
std::shared_ptr<Graph> Graph::copy() {
  auto new_g = std::make_shared<Graph>();
  new_g->cloneFrom(*this);
  return new_g;
}
762 
// Returns a deep copy of this graph with unique ownership.
std::unique_ptr<Graph> Graph::copyUnique() {
  auto new_g = std::make_unique<Graph>();
  new_g->cloneFrom(*this);
  return new_g;
}
768 
// Applies `type_map` to every value type in the block, recursing into
// sub-blocks and into graph-valued attributes (g / gs).
void Block::remapTypes(const std::function<TypePtr(TypePtr)>& type_map) {
  for (Value* input : inputs()) {
    input->setType(type_map(input->type()));
  }
  for (Node* node : nodes()) {
    for (Value* output : node->outputs()) {
      output->setType(type_map(output->type()));
    }
    for (Block* sub_block : node->blocks()) {
      sub_block->remapTypes(type_map);
    }
    // Graphs stored as attributes (e.g. Subgraph) must be remapped too.
    for (Symbol name : node->attributeNames()) {
      if (node->kindOf(name) == AttributeKind::g) {
        node->g(name)->remapTypes(type_map);
      } else if (node->kindOf(name) == AttributeKind::gs) {
        for (const auto& g : node->gs(name)) {
          g->remapTypes(type_map);
        }
      }
    }
  }
}
791 
// Applies `type_map` to every value type in the graph (delegates to the
// top-level block, which recurses).
void Graph::remapTypes(const std::function<TypePtr(TypePtr)>& type_map) {
  block()->remapTypes(type_map);
}
795 
// Sets this value's type to a TensorType derived from the given tensor.
void Value::inferTypeFrom(const at::Tensor& output) {
  setType(TensorType::create(output));
}
799 
// Sets this value's type to the runtime type of the given object.
void Value::inferTypeFrom(
    const c10::intrusive_ptr<c10::ivalue::Object>& output) {
  setType(output->type());
}
804 
// True when this value is statically known to be None (NoneType, or the
// producing node must produce None).
bool Value::mustBeNone() const {
  return type()->cast<NoneType>() || node_->mustBeNone();
}
// True when this value can be proven never to be None: not produced by
// AutogradAdd, not NoneType, not Optional, and not a Union that can hold None.
bool Value::mustNotBeNone() const {
  return node_->kind() != prim::AutogradAdd && type() != NoneType::get() &&
      !type()->cast<OptionalType>() &&
      !(type()->cast<UnionType>() &&
        type()->expect<UnionType>()->canHoldType(*NoneType::get()));
}
814 
debugNameBase() const815 std::string Value::debugNameBase() const {
816   std::string name = debugName();
817   std::string name_base = name;
818   auto last_dot_pos = name.find_last_of('.');
819   if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) {
820     if (name.find_first_not_of("0123456789", last_dot_pos + 1) ==
821         std::string::npos) {
822       name_base = name.substr(0, last_dot_pos);
823     }
824   }
825   return name_base;
826 }
827 
isValidName(const std::string & name)828 bool Value::isValidName(const std::string& name) {
829   // Empty strings are legal
830   if (name.empty()) {
831     return true;
832   }
833 
834   // Numbers are not legal
835   if (isNumber(name)) {
836     return false;
837   }
838 
839   return true;
840 }
841 
// Assigns a debug name to this value, maintaining the graph-wide uniqueness
// map. Passing "" clears the name. If another value already owns `name`,
// that value is renamed to "<base>.<suffix>" with the next free suffix.
// Throws std::runtime_error for invalid (purely numeric) names.
Value* Value::setDebugName(const std::string& name) {
  if (!isValidName(name)) {
    throw std::runtime_error("Invalid name: '" + name + "'");
  }

  auto& names = node()->owningGraph()->unique_names_;

  // clear any old name from the map
  if (hasDebugName()) {
    names.erase(unique_name_);
    unique_name_ = "";
  }

  // allow "" to clear the uniquename
  if (name.empty()) {
    return this;
  }

  // if someone else has this name, then rename the other value
  auto old_owner_of_name = names.find(name);
  if (old_owner_of_name != names.end()) {
    size_t suffix = 1;
    std::string name_base = name;
    auto last_dot_pos = name.find_last_of('.');
    // If the requested name already ends in ".<digits>", reuse its base and
    // start counting from that numeric suffix.
    if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) {
      if (name.find_first_not_of("0123456789", last_dot_pos + 1) ==
          std::string::npos) {
        suffix = std::stoll(name.substr(last_dot_pos + 1));
        name_base = name.substr(0, last_dot_pos);
      }
    }

    // Start from the highest suffix previously handed out for this base.
    auto& names_suffixes = node()->owningGraph()->name_base_suffix_;
    auto it = names_suffixes.find(name_base);
    if (it != names_suffixes.end()) {
      suffix = std::max(suffix, it->second + 1);
    }

    // Verify that new name is not used and find next usable name in case
    // suffix is used.
    std::string replacement_name;
    do {
      std::stringstream ss;
#ifndef _WIN32
      // Protect 12345 integer from becoming "1,2345" if some other process sets
      // global locale For more details see
      // https://github.com/pytorch/pytorch/issues/79583#issuecomment-1161260061
      static std::locale c_locale("C");
      ss.imbue(c_locale);
#endif
      ss << name_base << "." << suffix++;
      replacement_name = ss.str();
    } while (names.count(replacement_name) > 0);

    names_suffixes[name_base] = suffix;

    // Recursive call renames the previous owner to the free name found above.
    old_owner_of_name->second->setDebugName(replacement_name);
  }

  names[name] = this;
  unique_name_ = name;
  return this;
}
905 
// Copies type and (if present) debug name from another value; returns this
// for chaining.
Value* Value::copyMetadata(Value* from) {
  setType(from->type());
  if (from->hasDebugName()) {
    setDebugName(from->debugName());
  }
  return this;
}
913 
// Redirects the first recorded use of this value to newValue, updating both
// use lists. NOTE(review): indexes uses()[0] unconditionally — presumably
// callers guarantee at least one use exists; confirm before reusing.
void Value::replaceFirstUseWith(Value* newValue) {
  AT_ASSERT(owningGraph() == newValue->owningGraph());
  auto u = uses()[0];
  u.user->inputs_[u.offset] = newValue;
  newValue->uses_.push_back(u);
  uses_.erase(uses_.begin());
}
921 
// Redirects every use of this value to newValue, one use at a time.
void Value::replaceAllUsesWith(Value* newValue) {
  while (!uses().empty()) {
    replaceFirstUseWith(newValue);
  }
}
927 
// Redirect to `newValue` only the uses whose user node is topologically
// after `node`; uses at or before `node` are left untouched.
// NOTE(review): assumes newValue != this — pushing into newValue->uses_
// while iterating this->uses_ would otherwise invalidate the iteration.
void Value::replaceAllUsesAfterNodeWith(const Node* node, Value* newValue) {
  // First pass: repoint the qualifying input slots and record the uses on
  // newValue. The Use entries in our own list are left in place for now.
  std::for_each(uses_.begin(), uses_.end(), [&node, newValue](Use& u) {
    if (u.user->isAfter(node)) {
      u.user->inputs_[u.offset] = newValue;
      newValue->uses_.push_back(u);
    }
  });

  // Second pass: drop the transferred uses from our list via erase-remove.
  uses_.erase(
      std::remove_if(
          uses_.begin(),
          uses_.end(),
          [&node](const Use& u) { return u.user->isAfter(node); }),
      uses_.end());
}
943 
// Redirect to `newValue` only the uses whose user node is dominated by
// `node` (i.e. `node` executes before the user on every path). Mirrors
// replaceAllUsesAfterNodeWith but with a dominance test instead of a
// same-block ordering test.
void Value::replaceAllUsesDominatedByNodeWith(
    const Node* node,
    Value* newValue) {
  // Repoint dominated input slots and record the uses on newValue.
  std::for_each(uses_.begin(), uses_.end(), [&node, newValue](Use& u) {
    if (u.user->isDominatedBy(node)) {
      u.user->inputs_[u.offset] = newValue;
      newValue->uses_.push_back(u);
    }
  });

  // Remove the transferred uses from our own list.
  uses_.erase(
      std::remove_if(
          uses_.begin(),
          uses_.end(),
          [&node](const Use& u) { return u.user->isDominatedBy(node); }),
      uses_.end());
}
961 
findArgument(const FunctionSchema & the_schema,const std::string & unqualName)962 static size_t findArgument(
963     const FunctionSchema& the_schema,
964     const std::string& unqualName) {
965   for (const auto i : c10::irange(the_schema.arguments().size())) {
966     const Argument* arg = &the_schema.arguments()[i];
967     if (arg->name() == unqualName) {
968       return i;
969     }
970   }
971   throw std::runtime_error(
972       std::string("Couldn't find an argument called ") + unqualName);
973 }
974 
findArgument(const FunctionSchema & the_schema,Symbol name)975 static size_t findArgument(const FunctionSchema& the_schema, Symbol name) {
976   const auto unqualName = name.toUnqualString();
977   return findArgument(the_schema, unqualName);
978 }
979 
// Return the constant IValue bound to the named input, or nullopt when the
// input is not statically known (i.e. not a constant).
std::optional<IValue> Node::get(Symbol name) const {
  return toIValue(namedInput(name));
}
983 
hasNamedInput(const std::string & name) const984 bool Node::hasNamedInput(const std::string& name) const {
985   for (const auto& argument : schema().arguments()) {
986     if (argument.name() == name) {
987       return true;
988     }
989   }
990   return false;
991 }
992 
// Look up an input by its schema argument name; throws (via findArgument)
// when the name is not part of the schema.
Value* Node::namedInput(const std::string& unqualName) const {
  return input(findArgument(schema(), unqualName));
}
// Symbol overload of namedInput; matches on the unqualified name.
Value* Node::namedInput(Symbol name) const {
  return input(findArgument(schema(), name));
}
999 
// Return true when this node is a valid instance of `schema`: the operator
// name matches, every formal argument is satisfied by the corresponding
// input's type (with type-variable unification), and the arity is
// compatible (exact, or >= formals for vararg schemas).
bool Node::matches(const FunctionSchema& schema) const {
  if (isBlockListedSchema(schema)) {
    return false;
  }
  // wrong name
  if (kind().toQualString() != schema.name()) {
    return false;
  }
  at::ArrayRef<const Value*> actuals = inputs();
  const auto& formals = schema.arguments();

  // not enough inputs
  if (actuals.size() < formals.size()) {
    return false;
  }

  // Unify formal types against actual input types, accumulating type
  // variable bindings (e.g. T in List[T]) across arguments in type_env.
  TypeEnv type_env;
  for (const auto i : c10::irange(formals.size())) {
    auto formal = formals[i].type();
    const MatchTypeReturn matched_type =
        matchTypeVariables(formal, actuals[i]->type(), type_env);
    if (!matched_type.success()) {
      return false;
    }

    TypePtr resolved = tryEvalTypeVariables(formal, type_env);
    if (resolved) {
      formal = resolved;
    }
    // note: it is possible at this point that type variable matching has
    // not resolved all type variables, e.g. if None was matched to Optional[T]
    // we will not succeed at matching T. However None <: Optional[T] so this
    // check can still succeed.

    if (!actuals[i]->type()->isSubtypeOf(*formal)) {
      return false;
    }
  }

  // too many inputs
  if (!schema.is_vararg() && actuals.size() != formals.size()) {
    return false;
  }

  return true;
}
1046 
matches(const char * signature_literal,at::ArrayRef<Symbol> const_inputs) const1047 bool Node::matches(
1048     const char* signature_literal,
1049     at::ArrayRef<Symbol> const_inputs) const {
1050   if (!matches(getOperatorForLiteral(signature_literal)->schema())) {
1051     return false;
1052   }
1053   for (Symbol s : const_inputs) {
1054     if (!is_constant(s)) {
1055       return false;
1056     }
1057   }
1058   return true;
1059 }
1060 
mustBeNone() const1061 bool Node::mustBeNone() const {
1062   // We can statically deduce this Node has returning None if:
1063   return
1064       // It's an AutogradZero node, or ...
1065       kind_ == prim::AutogradZero ||
1066       // It has only one output and that output is NoneType, or ...
1067       (outputs().size() == 1 && output()->type() == NoneType::get()) ||
1068       // It's a constant optional with no value in the attributes.
1069       (kind_ == prim::Constant && !this->hasAttributes() &&
1070        output()->type()->cast<OptionalType>());
1071 }
1072 
// Debugging helper: print this node's textual form to stdout.
void Node::dump() const {
  std::cout << *this << "\n";
}
1076 
schema() const1077 const FunctionSchema& Node::schema() const {
1078   if (op_) {
1079     return op_->schema();
1080   }
1081   return getOperator().schema();
1082 }
1083 
maybeSchema() const1084 const FunctionSchema* Node::maybeSchema() const {
1085   if (auto op = maybeOperator()) {
1086     return &op->schema();
1087   }
1088   return nullptr;
1089 }
1090 
maybeOperator() const1091 const Operator* Node::maybeOperator() const {
1092   if (!op_) {
1093     const auto& candidates = getAllOperatorsFor(kind());
1094     for (const auto& candidate : candidates) {
1095       if (matches(candidate->schema())) {
1096         op_ = candidate.get();
1097         break;
1098       }
1099     }
1100   }
1101   return op_;
1102 }
1103 
// Return the registered operator for this node, or throw a detailed
// ErrorReport (node text, input types, candidate schemas, full graph) when
// no registered schema matches.
const Operator& Node::getOperator() const {
  const Operator* maybe = maybeOperator();
  if (maybe)
    return *maybe;

  // Build a diagnostic describing why resolution failed.
  auto er = ErrorReport(sourceRange());
  er << "Schema not found for node. File a bug report.\n";
  er << "Node: " << *this << "\n";
  er << "Input types:";
  for (const auto i : c10::irange(inputs().size())) {
    if (i > 0)
      er << ", ";
    er << *inputs()[i]->type();
  }
  // List every registered overload for this symbol, if any exist.
  const auto& candidates = getAllOperatorsFor(kind());
  if (!candidates.empty()) {
    er << "\ncandidates were:\n";
    for (auto& candidate : candidates) {
      er << "  " << candidate->schema() << "\n";
    }
  } else {
    er << "\nno candidates found\n";
  }
  er << "within the graph:\n";
  er << *owningGraph() << "\n";
  throw er;
}
1131 
// Return a runnable Operation for this node.
Operation Node::getOperation() const {
  // note: some operators require the node to produce a runnable operation,
  // which is why 'this' is passed here. getOperator() ensures that 'this'
  // matches the schema of the returned operator.
  return getOperator().getOperation(this);
}
1138 
// Return true when this aten op is nondeterministic according to its
// SchemaInfo. Non-aten ops and aten ops lacking a schema report false.
bool Node::isNondeterministic() const {
  const auto schema = maybeSchema();
  if (!kind().is_aten()) {
    return false;
  }
  // All aten ops are expected to have a schema. However this is left as a
  // warning instead of an assert to ensure that previous use cases do not
  // break.
  if (!schema) {
    TORCH_WARN("aten Schema not found.");
    return false;
  }
  torch::utils::SchemaInfo schema_info(*schema);
  // Ops like dropout are only nondeterministic in training mode; feed the
  // constant value of the "train" argument (when known) into the query.
  if (hasNamedInput("train")) {
    auto value = constant_as<bool>(namedInput("train"));
    if (value.has_value()) {
      schema_info.addArgumentValue("train", *value);
    }
  }
  return schema_info.is_nondeterministic();
}
1160 
// Return true when executing this node has observable effects beyond its
// outputs (I/O, RNG state, RPC, control-flow resources, ...), so it must
// not be DCE'd or reordered freely. Hard-coded kinds are checked first;
// everything else is decided by the operator's AliasAnalysisKind.
bool Node::hasSideEffects() const {
  switch (kind_) {
    case prim::PythonOp:
    case prim::IgnoredPythonOp:
    case prim::Print:
    case prim::RaiseException:
    case aten::warn:
    case aten::save:
    case aten::manual_seed:
    case prim::AddStatValue:
    case prim::TimePoint:
    case prim::CallFunction:
    case prim::CallMethod:
    case prim::BailoutTemplate:
    case prim::BailOut:
    case prim::rpc_async: // It represents RPC message sent.
    case prim::rpc_sync: // It represents RPC message sent.
    case prim::rpc_remote: // It represents RPC message sent.
    case aten::wait: // It can represent RPC message received.
#if !defined(USE_ROCM)
    case cuda::set_stream:
    case cuda::_set_device:
    case cuda::_current_device:
    case cuda::synchronize:
#endif
    case prim::Enter:
    case prim::Exit:
      return true;
  }

  auto op = maybeOperator();
  if (!op) {
    // Unregistered ops are only tolerated for prim; assume no side effects.
    TORCH_INTERNAL_ASSERT(
        kind_.is_prim(),
        "Only prim ops are allowed to not have a registered operator but ",
        kind_.toDisplayString(),
        " doesn't have one either. We don't know if this op has side effects.");
    return false;
  }

  if (kind_.is_prim() || kind_.is_aten() || kind_.is_cuda()) {
    // TODO There is nothing in the system that relies on aten:: and prim::
    // ops using AliasAnalysisKind::FROM_SCHEMA,
    // AliasAnalysisKind::INTERNAL_SPECIAL_CASE, or
    // AliasAnalysisKind::CONSERVATIVE but this is the intended behavior for all
    // current ops and a good error check. We can consider lifting this
    // constraint later if we have a use case for it.
    TORCH_INTERNAL_ASSERT(
        op->aliasAnalysisKind() == AliasAnalysisKind::INTERNAL_SPECIAL_CASE ||
            op->aliasAnalysisKind() == AliasAnalysisKind::FROM_SCHEMA ||
            op->aliasAnalysisKind() == AliasAnalysisKind::CONSERVATIVE,
        "aten:: and prim:: ops should have AliasAnalysisKind::INTERNAL_SPECIAL_CASE"
        ", AliasAnalysisKind::FROM_SCHEMA or AliasAnalysisKind::CONSERVATIVE but ",
        kind_.toDisplayString(),
        " has ",
        toString(op->aliasAnalysisKind()));
  }

  // CONSERVATIVE is the only alias-analysis kind treated as side-effectful.
  switch (op->aliasAnalysisKind()) {
    case AliasAnalysisKind::PURE_FUNCTION:
    case AliasAnalysisKind::FROM_SCHEMA:
    case AliasAnalysisKind::INTERNAL_SPECIAL_CASE:
      return false;
    case AliasAnalysisKind::CONSERVATIVE:
      return true;
  }
  TORCH_INTERNAL_ASSERT(false, "Unhandled AliasAnalysisKind case");
  return false; // silence compiler warning
}
1230 
// Assign this node a topological position, to facilitate fast isBefore() and
// isAfter() queries. Must be called right after a node is inserted into the
// node list.
//
// The basic scheme is: assign every node a position (uint64_t).  The common
// case (appending to the end of the graph) is made more efficient by advancing
// a fixed interval past the previous node and placing `this` there. Otherwise,
// assign `this` a position at the midpoint between its prev() and next()
// nodes.
//
// If we ever run out of space (by, e.g. inserting too much in place), we
// reindex by spreading out all the nodes again.
void Node::assignTopoPosition() {
  // param_node / return_node are the block's sentinels, so these tests
  // detect insertion at the head / tail of the block's node list.
  bool is_first = prev() == owningBlock()->param_node();
  bool is_last = next() == owningBlock()->return_node();

  const auto prevPos = prev()->topo_position_;
  const auto nextPos = next()->topo_position_;

  // Append to the end of the graph
  if (is_last) {
    if (is_first) {
      // the node list is empty, assign the first position
      topo_position_ = kMidPoint;
      return;
    }

    if (prevPos >= (kUpperBound - kAppendInterval)) {
      // we're running off the edge
      owningBlock()->reIndexTopology();
      return;
    }

    topo_position_ = prevPos + kAppendInterval;

    // Prepend to the graph
  } else if (is_first) {
    // next() is the first element in the block list
    if (nextPos <= (kLowerBound + kAppendInterval)) {
      // we're running off the edge
      owningBlock()->reIndexTopology();
      return;
    }
    topo_position_ = nextPos - kAppendInterval;

    // insert between two existing nodes
  } else {
    int64_t remaining = nextPos - prevPos;
    AT_ASSERT(remaining > 0);
    if (remaining == 1) {
      // There was no room
      owningBlock()->reIndexTopology();
      return;
    }
    // When repeatedly inserting at the same insertion point, bias the split
    // so later insertions still find room before triggering a reindex.
    int64_t predicted_future_insertions = 0;
    if (next() == graph_->insertPoint()) {
      predicted_future_insertions = graph_->predicted_insert_count_++;
    }
    topo_position_ = prevPos +
        std::max(int64_t(1), remaining / (2 + predicted_future_insertions));
    AT_ASSERT(prevPos < topo_position_ && topo_position_ < nextPos);
  }
}
1294 
// Construct a node of the given kind, registering it in the graph's
// all_nodes set. The node is not yet placed in any block; scope is
// inherited from the graph's current scope.
Node::Node(Graph* graph_, NodeKind kind_)
    : kind_(kind_),
      graph_(graph_),
      owning_block_(nullptr),
      scope_(graph_->current_scope_),
      callstack_(std::nullopt),
      op_(nullptr) {
  graph_->all_nodes.emplace(this);
}
1304 
// Remove output i, which must be unused, freeing its Value and shifting the
// recorded offsets of all later outputs down by one.
void Node::eraseOutput(size_t i) {
  AT_ASSERT(i < outputs_.size());
  AT_ASSERT(outputs_[i]->uses().empty());
  // Output shape changed, so any cached operator match is stale.
  op_ = nullptr;
  Value* n = outputs_[i];
  outputs_.erase(outputs_.begin() + i);
  owningGraph()->freeValue(n);
  // Fix up offsets of the outputs that moved left.
  for (const auto j : c10::irange(i, outputs_.size())) {
    outputs_[j]->offset_--;
  }
}
1316 
addBlock()1317 Block* Node::addBlock() {
1318   op_ = nullptr;
1319   blocks_.push_back(new Block(owningGraph(), this));
1320   return blocks_.back();
1321 }
1322 
eraseBlock(size_t i)1323 void Node::eraseBlock(size_t i) {
1324   AT_ASSERT(i < blocks_.size());
1325   op_ = nullptr;
1326   Block* n = blocks_[i];
1327   blocks_.erase(blocks_.begin() + i);
1328   n->destroy();
1329 }
1330 
// Fully tear down this node: outputs (which must be unused), sub-blocks,
// then inputs, then unlink from the block list and return the node to the
// graph. Order matters — outputs must be erased before the node is freed.
void Node::destroy() {
  while (!outputs().empty()) {
    eraseOutput(outputs().size() - 1);
  }
  while (!blocks().empty()) {
    eraseBlock(blocks().size() - 1);
  }
  removeAllInputs();
  if (inBlockList()) {
    removeFromList();
  }
  graph_->freeNode(this);
}
1344 
// Copy bookkeeping metadata (source range, scope, attributes, callstack)
// from `s` onto this node. Inputs/outputs are not touched.
void Node::cloneFrom(Node* s) {
  source_range_ = s->source_range_;
  // Only adopt the source scope when it carries information.
  if (s->scope_ && !s->scope_->isBlank()) {
    scope_ = s->scope_;
  }
  copyAttributes(*s);
  callstack_ = s->callstack_;
}
1353 
replaceAllUsesWith(Node * n)1354 void Node::replaceAllUsesWith(Node* n) {
1355   AT_ASSERT(outputs().size() == n->outputs().size());
1356   size_t nOutputs = outputs().size();
1357   for (const auto i : c10::irange(nOutputs)) {
1358     outputs()[i]->replaceAllUsesWith(n->outputs()[i]);
1359   }
1360 }
1361 
// Create a new node with kind `new_symbol`, mirroring this node's inputs,
// outputs (with metadata), attributes and metadata, insert it at this
// node's position, and redirect all uses to it. This node itself is NOT
// destroyed — the caller is responsible for removing it.
Node* Node::replaceWithNewSymbol(Symbol new_symbol) {
  WithInsertPoint insert_guard{this};
  // Remember whether the old symbol resolved to a registered operator so we
  // can check the replacement resolves consistently.
  bool had_operator = maybeOperator() != nullptr;
  auto graph = owningGraph();
  auto replace_node = graph->insertNode(graph->create(new_symbol, 0));
  for (Value* v : inputs()) {
    replace_node->addInput(v);
  }
  for (Value* v : outputs()) {
    auto new_out = replace_node->addOutput()->copyMetadata(v);
    v->replaceAllUsesWith(new_out);
  }
  replace_node->copyMetadata(this);
  replace_node->copyAttributes(*this);
  TORCH_INTERNAL_ASSERT(
      (replace_node->maybeOperator() != nullptr) == had_operator,
      "invalid symbol replacement:",
      new_symbol,
      kind());
  return replace_node;
}
1383 
isDominatedBy(const Node * dominator) const1384 bool Node::isDominatedBy(const Node* dominator) const {
1385   const Node* node = this;
1386   while (node) {
1387     if (node->owningBlock() == dominator->owningBlock()) {
1388       return dominator->isBefore(node);
1389     }
1390     node = node->owningBlock()->owningNode();
1391   }
1392   return false;
1393 }
1394 
// Insert `value` as the i-th input of this node, shifting later inputs
// right and updating their recorded use offsets. Returns `value`.
Value* Node::insertInput(size_t i, Value* value) {
  AT_ASSERT(graph_ == value->owningGraph());
  // Input list changed; the cached operator match is stale.
  op_ = nullptr;
  // First we update the offsets for all existing inputs that will reside
  // after the one we're inserting. Concretely, these are the inputs at
  // indices [i, # input). Since we're inserting one input before all of
  // these inputs, increment their use offsets for this value by 1
  for (const auto use_itr : c10::irange(i, inputs_.size())) {
    // See Note [User node does not uniquely identify use]
    auto use = findUseForInput(use_itr);
    use->offset += 1;
  }
  // Insert the actual input at the specified index
  inputs_.insert(inputs_.begin() + i, value);
  // Register the new use of the value we're inserted as an input.
  value->uses_.emplace_back(this, i);
  return value;
}
1413 
// Append `value` as the last input of this node, recording the use at the
// new index. Returns `value`.
Value* Node::addInput(Value* value) {
  AT_ASSERT(graph_ == value->owningGraph());
  // Input list changed; invalidate the cached operator match.
  op_ = nullptr;
  // The use's offset is the index the value is about to occupy.
  value->uses_.emplace_back(this, inputs_.size());
  inputs_.push_back(value);
  return value;
}
1421 
// Replace input i with `newValue`, unregistering the old value's use and
// registering a new one. Returns the previous input value.
Value* Node::replaceInput(size_t i, Value* newValue) {
  AT_ASSERT(newValue->owningGraph() == graph_);
  // Input changed; invalidate the cached operator match.
  op_ = nullptr;
  Value* old = dropInput(i);
  inputs_[i] = newValue;
  newValue->uses_.emplace_back(this, i);
  return old;
}
1430 
replaceInputWith(Value * from,Value * to)1431 void Node::replaceInputWith(Value* from, Value* to) {
1432   AT_ASSERT(from->owningGraph() == graph_);
1433   AT_ASSERT(to->owningGraph() == graph_);
1434   op_ = nullptr;
1435   size_t i = 0;
1436   for (auto input : inputs()) {
1437     if (input == from) {
1438       replaceInput(i, to);
1439     }
1440     i++;
1441   }
1442 }
1443 
addOutput()1444 Value* Node::addOutput() {
1445   outputs_.push_back(new Value(this, outputs_.size()));
1446   op_ = nullptr;
1447   return outputs_.back();
1448 }
1449 
// Insert a fresh output Value at index i, shifting the recorded offsets of
// all later outputs up by one. Returns the new output.
Value* Node::insertOutput(size_t i) {
  // Output list changed; invalidate the cached operator match.
  op_ = nullptr;
  outputs_.insert(outputs_.begin() + i, new Value(this, i));
  for (size_t itr = i + 1; itr < outputs_.size(); ++itr) {
    outputs_[itr]->setOffset(outputs_[itr]->offset() + 1);
  }
  return outputs_.at(i);
}
1458 
// Compare the topological order of this node and `n`, possibly across
// nested blocks. When the nodes live in different blocks, the comparison is
// performed between their ancestors in the first block the two blockchains
// share. Asserts (unreachable) when the nodes share no graph.
bool Node::isBeforeOrAfter(const Node* n, MoveSide moveSide) const {
  // Fast path: same block — compare cached topological positions directly.
  if (this->owningBlock() == n->owningBlock()) {
    if (moveSide == MoveSide::BEFORE) {
      return this->topo_position_ < n->topo_position_;
    }

    if (moveSide == MoveSide::AFTER) {
      return this->topo_position_ > n->topo_position_;
    }

    AT_ASSERT(this == n);
    return false;
  }

  // These nodes don't share a common block. Traverse the blockchains upward
  // until we find the first common block.
  auto lhs = this;
  while (lhs) {
    AT_ASSERT(lhs->owningBlock());

    auto rhs = n;
    while (rhs) {
      if (!rhs->owningBlock()) {
        break;
      }

      if (lhs->owningBlock() == rhs->owningBlock()) {
        // Recurse: this time both ancestors share a block, so the fast path
        // above resolves the comparison.
        return lhs->isBeforeOrAfter(rhs, moveSide);
      }
      rhs = rhs->owningBlock()->owningNode();
    }

    lhs = lhs->owningBlock()->owningNode();
  }
  // should never reach here, since both nodes are ultimately in the same graph
  AT_ASSERT(false);
}
1496 
// True when this node executes strictly before `n` in topological order.
bool Node::isBefore(const Node* n) const {
  return isBeforeOrAfter(n, MoveSide::BEFORE);
}
1500 
// True when this node executes strictly after `n` in topological order.
bool Node::isAfter(const Node* n) const {
  return isBeforeOrAfter(n, MoveSide::AFTER);
}
1504 
// Link this (currently unlinked) node into n's block immediately before n.
Node* Node::insertBefore(Node* n) {
  AT_ASSERT(n->inBlockList());
  // Inserting after n's predecessor is the same as inserting before n.
  insertAfter(n->prev());
  return this;
}
1510 
// Link this (currently unlinked) node into n's block immediately after n,
// splicing the doubly-linked node list and assigning a topological
// position. Inserting after the Return sentinel is forbidden.
Node* Node::insertAfter(Node* n) {
  AT_ASSERT(!inBlockList() && n->inBlockList());
  AT_ASSERT(n->owningBlock());
  AT_ASSERTM(
      n->kind() != prim::Return,
      "Attempting to insert a Node after the Return node or before the Param node. Tried to insert",
      *this,
      " after ",
      *n,
      ".");
  this->owning_block_ = n->owningBlock();
  // Splice: n -> this -> (n's old successor).
  Node* next = n->next();
  n->next() = this;
  this->prev() = n;
  this->next() = next;
  next->prev() = this;
  assignTopoPosition();
  return this;
}
1530 
// Relocate this node to the position immediately after `n`.
void Node::moveAfter(Node* n) {
  removeFromList();
  insertAfter(n);
}
1535 
// Relocate this node to the position immediately before `n`.
void Node::moveBefore(Node* n) {
  removeFromList();
  insertBefore(n);
}
1540 
// Remove input i, unregistering its use and shifting the use offsets of all
// later inputs down by one.
void Node::removeInput(size_t i) {
  // Input list changed; invalidate the cached operator match.
  op_ = nullptr;
  dropInput(i);
  // everything after this input shifts left,
  // so we need to update their use offsets to match
  for (size_t j = i + 1; j < inputs_.size(); j++) {
    auto it = findUseForInput(j);
    it->offset--;
  }
  inputs_.erase(inputs_.begin() + i);
}
1552 
// Drop every input's use registration, then clear the input list. No
// per-element offset fixups are needed since the whole list goes away.
void Node::removeAllInputs() {
  op_ = nullptr;
  for (const auto i : c10::irange(inputs().size())) {
    dropInput(i);
  }
  inputs_.clear();
}
1560 
removeAllOutputs()1561 void Node::removeAllOutputs() {
1562   op_ = nullptr;
1563   size_t init_osize = outputs_.size();
1564   for (auto i : c10::irange(init_osize)) {
1565     eraseOutput(init_osize - i - 1);
1566   }
1567 }
1568 
// Reorder the inputs so that new input i is old input new_order[i],
// updating each use's recorded offset. `new_order` must be a permutation of
// [0, #inputs); the nulling-out trick detects repeated indices.
void Node::permuteInputs(const std::vector<size_t>& new_order) {
  op_ = nullptr;
  AT_ASSERT(new_order.size() == inputs_.size());
  std::vector<Value*> new_inputs;
  new_inputs.reserve(new_order.size());
  for (const auto i : c10::irange(new_order.size())) {
    // A slot already consumed reads back as nullptr — catches duplicates.
    AT_ASSERTM(inputs_.at(new_order[i]) != nullptr, "Repeated index");
    new_inputs.push_back(inputs_.at(new_order[i]));
    // Find the use while the old offset is still recorded, then rewrite it.
    auto it = findUseForInput(new_order[i]);
    it->offset = i;
    inputs_.at(new_order[i]) = nullptr;
  }
  inputs_ = std::move(new_inputs);
}
1583 
// Reorder the outputs so that new output i is old output new_order[i],
// updating each Value's stored offset. `new_order` must be a permutation of
// [0, #outputs); the nulling-out trick detects repeated indices.
void Node::permuteOutputs(const std::vector<size_t>& new_order) {
  op_ = nullptr;
  AT_ASSERT(new_order.size() == outputs_.size());
  std::vector<Value*> new_outputs;
  new_outputs.reserve(new_order.size());
  for (const auto i : c10::irange(new_order.size())) {
    // A slot already consumed reads back as nullptr — catches duplicates.
    AT_ASSERTM(outputs_.at(new_order[i]) != nullptr, "Repeated index");
    new_outputs.push_back(outputs_.at(new_order[i]));
    outputs_.at(new_order[i])->setOffset(i);
    outputs_.at(new_order[i]) = nullptr;
  }
  outputs_ = std::move(new_outputs);
}
1597 
// Locate the Use entry (this node, offset i) in input i's use list.
// Asserts that the use exists — the lists must be consistent.
use_list::iterator Node::findUseForInput(size_t i) {
  auto& input_uses = inputs_[i]->uses_;
  // O(N) on the use list, but unless we get nodes with +100 uses
  // vector traversal still is probably faster than linked list
  auto use_it = std::find(input_uses.begin(), input_uses.end(), Use(this, i));
  AT_ASSERT(use_it != input_uses.end());
  return use_it;
}
1606 
// Unregister input i's use and null out the slot WITHOUT shrinking the
// input list — callers (removeInput/replaceInput/removeAllInputs) handle
// the list themselves. Returns the dropped value.
Value* Node::dropInput(size_t i) {
  AT_ASSERT(i < inputs_.size());
  auto input_node = inputs_[i];
  auto use_it = findUseForInput(i);
  input_node->uses_.erase(use_it);
  inputs_[i] = nullptr;
  return input_node;
}
1615 
// Unlink this node from its block's doubly-linked node list. The node is
// not destroyed; it can be re-inserted elsewhere.
void Node::removeFromList() {
  AT_ASSERT(inBlockList());
  this->owning_block_ = nullptr;
  Node* next = this->next();
  Node* prev = this->prev();
  prev->next() = next;
  next->prev() = prev;
  this->next() = nullptr;
  this->prev() = nullptr;
}
1626 
// Return the innermost block that (transitively) contains both this node
// and `n`. Works by first equalizing the nesting depth of the two nodes,
// then walking both chains up in lockstep until the blocks coincide.
Block* Node::findCommonAncestorBlockWith(Node* n) {
  if (n->owningBlock() == owningBlock()) {
    return owningBlock();
  }

  Node* n1 = this;
  Node* n2 = n;

  size_t d_1 = n1->blocksFromGraphBlock();
  size_t d_2 = n2->blocksFromGraphBlock();

  for (; d_1 > d_2; --d_1) {
    n1 = n1->owningBlock()->owningNode();
    // n2 contains n1
  }

  for (; d_2 > d_1; --d_2) {
    n2 = n2->owningBlock()->owningNode();
  }

  // Now they are the same number of blocks from the graph block,
  // recurse upwards, checking if they are on the same block
  while (true) {
    if (n1->owningBlock() == n2->owningBlock()) {
      return n1->owningBlock();
    }

    n1 = n1->owningBlock()->owningNode();
    n2 = n2->owningBlock()->owningNode();

    // Both chains are the same depth, so they must reach the graph's top
    // block together; hitting null here would mean disjoint graphs.
    AT_ASSERT(n1 != nullptr);
    AT_ASSERT(n2 != nullptr);
  }
}
1661 
blocksFromGraphBlock()1662 size_t Node::blocksFromGraphBlock() {
1663   Node* n = this;
1664   size_t dist = 0;
1665   while (n->owningBlock()->owningNode()) {
1666     n = n->owningBlock()->owningNode();
1667     ++dist;
1668   }
1669   return dist;
1670 }
1671 
// Placeholder SourceRange for nodes inserted without source information.
// The static local is initialized once (thread-safe in C++11+).
inline const SourceRange& fakeRange() {
  static SourceRange range(std::make_shared<Source>(std::string("")), 0, 1);
  return range;
}
1676 
insert(Symbol opname,at::ArrayRef<NamedValue> args,at::ArrayRef<NamedValue> kwargs,const std::optional<SourceRange> & range)1677 Value* Graph::insert(
1678     Symbol opname,
1679     at::ArrayRef<NamedValue> args,
1680     at::ArrayRef<NamedValue> kwargs,
1681     const std::optional<SourceRange>& range) {
1682   return emitBuiltinCall(
1683       range.value_or(fakeRange()), *this, opname, args, kwargs);
1684 }
1685 
create(NodeKind kind,size_t num_outputs)1686 Node* Graph::create(NodeKind kind, size_t num_outputs) {
1687   // NB: Node constructor adds node to all_nodes
1688   auto n = new Node(this, kind);
1689   for (const auto i : c10::irange(num_outputs)) {
1690     (void)i;
1691     n->addOutput();
1692   }
1693   return n;
1694 }
1695 
create(NodeKind kind,ArrayRef<Value * > inputs,size_t num_outputs)1696 Node* Graph::create(
1697     NodeKind kind,
1698     ArrayRef<Value*> inputs,
1699     size_t num_outputs) {
1700   auto n = create(kind, num_outputs);
1701   for (auto i : inputs) {
1702     n->addInput(i);
1703   }
1704   return n;
1705 }
1706 
// Create a prim::AutogradZero node (stands for an all-zero autograd value).
Node* Graph::createAutogradZero() {
  return create(prim::AutogradZero);
}
1710 
createNone()1711 Node* Graph::createNone() {
1712   Node* n = create(prim::Constant);
1713   n->output()->setType(NoneType::get());
1714   return n;
1715 }
1716 
// Create a prim::Uninitialized node whose output carries the given type
// (a placeholder value that must never be observed at runtime).
Node* Graph::createUninitialized(TypePtr typ) {
  Node* n = create(prim::Uninitialized);
  n->output()->setType(std::move(typ));
  return n;
}
1722 
// Create a node of `kind` carrying a fresh, empty subgraph in its Subgraph
// attribute; the subgraph inherits the current scope.
Node* Graph::createWithSubgraph(Symbol kind) {
  auto n = create(kind, 0);
  n->g_(attr::Subgraph, std::make_shared<Graph>(current_scope()));
  return n;
}
1728 
createTuple(at::ArrayRef<Value * > values,TupleTypePtr tuple_type)1729 Node* Graph::createTuple(at::ArrayRef<Value*> values, TupleTypePtr tuple_type) {
1730   TORCH_INTERNAL_ASSERT(
1731       !tuple_type || tuple_type->schema(),
1732       "only pass tuple_type when creating a named tuple");
1733   if (!tuple_type) {
1734     auto types = fmap(values, [](Value* v) { return v->type(); });
1735     tuple_type = TupleType::create(std::move(types));
1736   }
1737   auto n = create(prim::TupleConstruct, values);
1738 
1739   n->output()->setType(tuple_type);
1740   return n;
1741 }
1742 
createTupleUnpack(Value * v)1743 Node* Graph::createTupleUnpack(Value* v) {
1744   TupleTypePtr tt = v->type()->expect<TupleType>();
1745   auto n = create(prim::TupleUnpack, {v}, 0);
1746   for (auto& element : tt->elements()) {
1747     n->addOutput()->setType(element);
1748   }
1749   return n;
1750 }
1751 
// Create a prim::TupleIndex node selecting element `idx` of tuple `tup`;
// the caller supplies the statically-known element type for the output.
Node* Graph::createTupleIndex(
    Value* tup,
    Value* idx,
    const TypePtr& output_type) {
  auto n = create(prim::TupleIndex, {tup, idx});
  n->output()->setType(output_type);
  return n;
}
1760 
// Build a new tuple containing `num_values` elements of `tup`, starting at
// index `beg` and stepping by `step_size`. Emits one constant index +
// TupleIndex node per element, then a TupleConstruct over the results.
// Indices must stay within the tuple's bounds.
Node* Graph::createTupleSlice(
    Value* tup,
    int64_t beg,
    int64_t step_size,
    int64_t num_values) {
  std::vector<Value*> new_vals;
  TupleTypePtr tt = tup->type()->expect<TupleType>();
  new_vals.reserve(num_values);

  int64_t i = beg;
  for (const auto j : c10::irange(num_values)) {
    (void)j; // Suppress unused variable warning
    auto idx = insertConstant(IValue(static_cast<int64_t>(i)));
    // Each selected element keeps its statically-known type.
    auto tupleIndex = insertNode(createTupleIndex(tup, idx, tt->elements()[i]));

    new_vals.push_back(tupleIndex->output());
    i += step_size;
  }

  auto n = createTuple(new_vals);
  return n;
}
1783 
createEnumName(Value * e)1784 Node* Graph::createEnumName(Value* e) {
1785   e->type()->expect<EnumType>();
1786   assert(e->type()->cast<EnumType>());
1787   auto n = create(prim::EnumName, {e});
1788   n->output()->setType(StringType::get());
1789   return n;
1790 }
1791 
createEnumValue(Value * e)1792 Node* Graph::createEnumValue(Value* e) {
1793   auto enum_type = e->type()->expect<EnumType>();
1794   auto n = create(prim::EnumValue, {e});
1795   n->output()->setType(enum_type->getValueType());
1796   return n;
1797 }
1798 
// Create a prim::ListConstruct node over `values`, typed List[contained_type].
// Every value must subtype `contained_type` (TORCH_CHECK otherwise).
Node* Graph::createList(
    const TypePtr& contained_type,
    at::ArrayRef<Value*> values) {
  auto n = create(prim::ListConstruct, values);
  for (const auto& v : values) {
    TORCH_CHECK(
        v->type()->isSubtypeOf(*contained_type),
        "Expected a list element that subtypes '",
        contained_type->repr_str(),
        "' but got an element of type '",
        v->type()->repr_str(),
        "'");
  }
  n->output()->setType(ListType::create(contained_type));
  return n;
}
1815 
createListUnpack(Value * v,size_t size)1816 Node* Graph::createListUnpack(Value* v, size_t size) {
1817   ListTypePtr list_type = v->type()->expect<ListType>();
1818   TypePtr elem_type = list_type->getElementType();
1819   auto n = create(prim::ListUnpack, {v}, 0);
1820   for (const auto i : c10::irange(size)) {
1821     (void)i; // Suppress unused variable warning
1822     n->addOutput()->setType(elem_type);
1823   }
1824   return n;
1825 }
1826 
// Create a prim::DictConstruct node typed Dict[key_type, value_type].
// Inputs are the interleaved key/value pairs (k0, v0, k1, v1, ...); each
// key/value must subtype the declared key/value type.
Node* Graph::createDict(
    const TypePtr& key_type,
    const TypePtr& value_type,
    at::ArrayRef<Value*> keys,
    at::ArrayRef<Value*> values) {
  AT_ASSERT(keys.size() == values.size());
  auto n = create(prim::DictConstruct, 1);
  for (const auto i : c10::irange(keys.size())) {
    AT_ASSERT(keys[i]->type()->isSubtypeOf(*key_type));
    AT_ASSERT(values[i]->type()->isSubtypeOf(*value_type));

    n->addInput(keys[i]);
    n->addInput(values[i]);
  }
  n->output()->setType(DictType::create(key_type, value_type));
  return n;
}
1844 
createNumToTensor(Value * value)1845 Node* Graph::createNumToTensor(Value* value) {
1846   Node* result = create(prim::NumToTensor, {value});
1847   result->output()->setType(TensorType::fromNumberType(*value->type()));
1848   return result;
1849 }
1850 
createObject(const ClassTypePtr & type)1851 Node* Graph::createObject(const ClassTypePtr& type) {
1852   auto result = create(prim::CreateObject);
1853   result->output()->setType(type);
1854   return result;
1855 }
1856 
createSetAttr(Value * obj,const std::string & field,Value * newValue)1857 Node* Graph::createSetAttr(
1858     Value* obj,
1859     const std::string& field,
1860     Value* newValue) {
1861   auto n = create(prim::SetAttr, {obj, newValue}, /*num_outputs=*/0);
1862   n->s_(attr::name, field);
1863   return n;
1864 }
1865 
createGetAttr(Value * obj,const std::string & field)1866 Node* Graph::createGetAttr(Value* obj, const std::string& field) {
1867   const auto classType = obj->type()->expect<ClassType>();
1868 
1869   auto n = create(prim::GetAttr, {obj}, /*num_outputs=*/1);
1870   n->s_(attr::name, field);
1871 
1872   const auto outputType = classType->getAttribute(field);
1873   n->output()->setType(outputType);
1874   n->output()->setDebugName(normalizeAttrName(field));
1875   return n;
1876 }
1877 
createStore(const std::string & name,Value * v)1878 Node* Graph::createStore(const std::string& name, Value* v) {
1879   auto n = create(prim::Store, {v}, /*num_outputs*/ 0);
1880   n->s_(attr::name, name);
1881   return n;
1882 }
1883 
createLoad(const std::string & name,const TypePtr & type)1884 Node* Graph::createLoad(const std::string& name, const TypePtr& type) {
1885   auto n = create(prim::Load, {}, /*num_outputs*/ 1);
1886   n->s_(attr::name, name);
1887   n->output()->setType(type);
1888   return n;
1889 }
1890 
createIsInstance(Value * v,at::ArrayRef<TypePtr> types)1891 Node* Graph::createIsInstance(Value* v, at::ArrayRef<TypePtr> types) {
1892   auto n = create(prim::isinstance, {v}, /*num_outputs*/ 1);
1893   n->tys_(attr::types, types.vec());
1894   n->output()->setType(BoolType::get());
1895   return n;
1896 }
insertUncheckedCast(Value * v,TypePtr type)1897 Value* Graph::insertUncheckedCast(Value* v, TypePtr type) {
1898   Node* n = insertNode(create(prim::unchecked_cast, {v}));
1899   n->output()->setType(std::move(type));
1900   return n->output();
1901 }
1902 
insertToList(Value * v,TypePtr type)1903 Value* Graph::insertToList(Value* v, TypePtr type) {
1904   int dim = 0;
1905   TypePtr ptr = type;
1906 
1907   // Unwrap the type to determine the number of dimensions.
1908   while (auto list_type = ptr->cast<ListType>()) {
1909     ptr = list_type->getElementType();
1910     ++dim;
1911   }
1912 
1913   // Encode the base element type as an integer.
1914   int elem_ty = 0;
1915   if (ptr == IntType::get()) {
1916     elem_ty = 0;
1917   } else if (ptr == FloatType::get()) {
1918     elem_ty = 1;
1919   } else if (ptr == BoolType::get()) {
1920     elem_ty = 2;
1921   } else if (ptr == ComplexType::get()) {
1922     elem_ty = 3;
1923   } else {
1924     TORCH_CHECK(
1925         false,
1926         ptr->repr_str(),
1927         " is not one of the supported element types for tolist: int, float, complex, bool");
1928   }
1929 
1930   // Pass in the number of dimensions and base element type as arguments
1931   // to the op.
1932   Value* dim_val = insertConstant(IValue(dim));
1933   Value* elem_ty_val = insertConstant(IValue(elem_ty));
1934   Node* n = insertNode(create(prim::tolist, {v, dim_val, elem_ty_val}));
1935   n->output()->setType(std::move(type));
1936   return n->output();
1937 }
1938 
insertFunctionCall(Function * callee,const MatchedSchema & matched)1939 Value* Graph::insertFunctionCall(
1940     Function* callee,
1941     const MatchedSchema& matched) {
1942   std::string func_name = callee->name();
1943   Value* fn_constant = insertNode(create(prim::Constant))
1944                            ->s_(attr::name, func_name)
1945                            ->output()
1946                            ->setType(FunctionType::create(callee));
1947   std::vector<Value*> inputs = {fn_constant};
1948   inputs.insert(inputs.end(), matched.inputs.begin(), matched.inputs.end());
1949   Value* result = insertNode(create(prim::CallFunction, inputs))
1950                       ->output()
1951                       ->setType(matched.return_types.at(0));
1952   return result;
1953 }
1954 
insertMethodCall(std::string method_name,const MatchedSchema & matched)1955 Value* Graph::insertMethodCall(
1956     std::string method_name,
1957     const MatchedSchema& matched) {
1958   Value* result = insertNode(create(prim::CallMethod, matched.inputs))
1959                       ->s_(attr::name, std::move(method_name))
1960                       ->output()
1961                       ->setType(matched.return_types.at(0));
1962   return result;
1963 }
1964 
// Clones `n` (which may belong to a different graph) into this graph.
// `value_map` translates each input Value of `n` to a Value valid in this
// graph. When `copy_blocks` is true, nested blocks are cloned recursively
// using the same mapping. Returns the new node; it is NOT inserted anywhere.
Node* Graph::createClone(
    Node* n,
    const std::function<Value*(Value*)>& value_map,
    bool copy_blocks) {
  // n can be from a different graph
  Node* r = n->allocNewInstance(this);
  // Outputs are created first so cloneFrom/metadata copying can assume they
  // exist; inputs are mapped through value_map afterwards.
  for (auto o : n->outputs()) {
    r->addOutput()->copyMetadata(o);
  }
  r->cloneFrom(n);
  for (auto i : n->inputs()) {
    r->addInput(value_map(i));
  }
  if (copy_blocks) {
    for (auto b : n->blocks()) {
      r->addBlock()->cloneFrom(b, value_map);
    }
  }
  return r;
}
1985 
// Inserts a constant node holding `val` at the current insertion point,
// optionally tagged with a source range and scope. Delegates to the free
// function jit::insertConstant (see ir/constants.h).
Value* Graph::insertConstant(
    const IValue& val,
    std::optional<SourceRange> loc,
    std::optional<ScopePtr> scope) {
  return jit::insertConstant(*this, val, std::move(loc), std::move(scope));
}
1992 
toString(bool print_source_locations) const1993 std::string Graph::toString(bool print_source_locations) const {
1994   std::ostringstream oss;
1995   print(oss, print_source_locations);
1996   return oss.str();
1997 }
1998 
// The graph owns every Node, Value and Block it allocated; they are tracked
// in all_nodes / all_values / all_blocks, so delete everything still tracked.
Graph::~Graph() {
  for (const Node* n : all_nodes) {
    delete n;
  }
  for (const Value* v : all_values) {
    delete v;
  }
  for (const Block* b : all_blocks) {
    delete b;
  }
}
2010 
freeNode(Node * n)2011 void Graph::freeNode(Node* n) {
2012   auto it = all_nodes.find(n);
2013   AT_ASSERT(it != all_nodes.end());
2014   delete *it;
2015   all_nodes.erase(it);
2016 }
freeValue(Value * v)2017 void Graph::freeValue(Value* v) {
2018   v->setDebugName("");
2019   auto it = all_values.find(v);
2020   AT_ASSERT(it != all_values.end());
2021   delete *it;
2022   all_values.erase(it);
2023 }
freeBlock(Block * b)2024 void Graph::freeBlock(Block* b) {
2025   auto it = all_blocks.find(b);
2026   AT_ASSERT(it != all_blocks.end());
2027   delete *it;
2028   all_blocks.erase(it);
2029 }
2030 
createTupleUnpack(Value * v)2031 at::ArrayRef<Value*> createTupleUnpack(Value* v) {
2032   // small peephole optimization to ensure IntArrayRef attributes can still turn
2033   // into constants e.g. in x.expand([3, 4])
2034   if (v->node()->kind() == prim::TupleConstruct) {
2035     return v->node()->inputs();
2036   }
2037   auto& g = *v->owningGraph();
2038   return g.insertNode(g.createTupleUnpack(v))->outputs();
2039 }
2040 
// Forward declaration: inlineCallStackOfNode and inlineCallStackOfBlock are
// mutually recursive (a node's blocks contain nodes).
void inlineCallStackOfNode(
    Node* n,
    std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries,
    Function* callee,
    Node* to_replace,
    const std::optional<ModuleInstanceInfo>& m_info);
2047 
inlineCallStackOfBlock(Block * b,std::unordered_map<InlinedCallStack *,InlinedCallStackPtr> & new_cs_entries,Function * callee,Node * to_replace,const std::optional<ModuleInstanceInfo> & m_info)2048 static void inlineCallStackOfBlock(
2049     Block* b,
2050     std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries,
2051     Function* callee,
2052     Node* to_replace,
2053     const std::optional<ModuleInstanceInfo>& m_info) {
2054   for (auto n : b->nodes()) {
2055     inlineCallStackOfNode(n, new_cs_entries, callee, to_replace, m_info);
2056   }
2057 }
2058 
inlineCallStackOfNode(Node * new_node,std::unordered_map<InlinedCallStack *,InlinedCallStackPtr> & new_cs_entries,Function * callee,Node * to_replace,const std::optional<ModuleInstanceInfo> & m_info)2059 void inlineCallStackOfNode(
2060     Node* new_node,
2061     std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries,
2062     Function* callee,
2063     Node* to_replace,
2064     const std::optional<ModuleInstanceInfo>& m_info) {
2065   auto new_node_cs = new_node->callstack();
2066 
2067   InlinedCallStack* raw_callstack_ptr =
2068       new_node_cs ? new_node_cs->get() : nullptr;
2069 
2070   if (!new_cs_entries.count(raw_callstack_ptr)) {
2071     if (new_node_cs) {
2072       new_cs_entries[raw_callstack_ptr] = c10::make_intrusive<InlinedCallStack>(
2073           *new_node_cs, callee, to_replace->sourceRange(), m_info);
2074     } else {
2075       new_cs_entries[raw_callstack_ptr] = c10::make_intrusive<InlinedCallStack>(
2076           callee, to_replace->sourceRange(), m_info);
2077     }
2078   }
2079   new_node->setCallStack(new_cs_entries.at(raw_callstack_ptr));
2080   // We updated the inlined callstack of new_node.
2081   // Same must be done for the nodes of the blocks of new_node.
2082   // For example If node's block otherwise is not annotated appropriately.
2083   for (auto block : new_node->blocks()) {
2084     inlineCallStackOfBlock(block, new_cs_entries, callee, to_replace, m_info);
2085   }
2086 }
2087 
// Replaces the call node `to_replace` with an inlined copy of `callee_graph`:
// the callee's body is spliced in at the call site, every use of the call's
// outputs is rewired to the corresponding inlined output, the new nodes'
// inlined call stacks are extended with this call, and `to_replace` is
// destroyed. Returns the inlined graph's outputs.
std::vector<Value*> inlineCallTo(
    Node* to_replace,
    GraphFunction* callee,
    Graph* callee_graph) {
  WithInsertPoint guard(to_replace);
  std::unordered_map<Value*, Value*> value_map;
  std::vector<torch::jit::Value*> new_outputs = insertGraph(
      *to_replace->owningGraph(),
      *callee_graph,
      to_replace->inputs(),
      value_map);

  std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>
      new_callstack_entries;

  // For method calls, try to recover which module instance the call was made
  // on, so debug info can later report a module hierarchy.
  std::optional<ModuleInstanceInfo> module_instance_info = std::nullopt;
  if (to_replace->kind() == prim::CallMethod) {
    auto class_type_ptr = to_replace->input(0)->type()->cast<c10::ClassType>();
    if (to_replace->input(0)->node()->kind() == prim::GetAttr) {
      // Receiver is an attribute access: use the attribute name as the
      // instance name.
      module_instance_info = std::make_optional(ModuleInstanceInfo(
          class_type_ptr, to_replace->input(0)->node()->s(attr::name)));
    } else if (
        !to_replace->owningGraph()->inputs().empty() &&
        to_replace->input(0) == to_replace->owningGraph()->inputs()[0]) {
      // This CallMethod must correspond to method of the same object
      // to which this graph belongs.
      module_instance_info =
          std::make_optional(ModuleInstanceInfo(class_type_ptr, "SELF"));
    } else {
      // Not sure if it is possible to come here ever.
      // TODO: Remove this else. Or add assert
      module_instance_info = std::make_optional(
          ModuleInstanceInfo(class_type_ptr, "INSTANCE_NAME_UNKNOWN"));
    }
  }

  // TODO: We might need to use nodes_map instead of value_map. Otherwise, we
  // are missing nodes without outputs (e.g. prim::Print).
  std::unordered_set<Node*> updated_nodes;
  for (const auto& kv : value_map) {
    /* Skip the old value if it is the graph input.
     * The reason is that, value_map contains values not all for the nodes of
     * the graph but primary inputs as well, and it will create duplicates when
     * the first inlined graph is input to the next one. To avoid this issue,
     * skip the old value when it is one of the
     * callee->optimized_graph()->inputs() or callee->graph()->inputs(), depends
     * on if it is inlined_optimized_graph
     */
    auto is_graph_input = std::find(
        callee_graph->inputs().begin(), callee_graph->inputs().end(), kv.first);
    if (is_graph_input != callee_graph->inputs().end()) {
      continue;
    }

    // A node with several outputs appears once per output in value_map;
    // update each new node only once.
    Node* new_node = kv.second->node();
    if (!updated_nodes.insert(new_node).second) {
      continue;
    }

    inlineCallStackOfNode(
        new_node,
        new_callstack_entries,
        callee,
        to_replace,
        module_instance_info);
  }
  const auto& old_outputs = to_replace->outputs();

  AT_ASSERT(new_outputs.size() == old_outputs.size());
  for (const auto i : c10::irange(old_outputs.size())) {
    // Preserve debug names so downstream tooling sees stable value names.
    if (old_outputs[i]->hasDebugName()) {
      new_outputs[i]->setDebugName(old_outputs[i]->debugName());
    }
    old_outputs[i]->replaceAllUsesWith(new_outputs[i]);
  }
  to_replace->destroy();

  return new_outputs;
}
2167 
2168 // inline_optimized_graph argument is used in substitute function call for
2169 // ONNX conversion
inlineCallTo(Node * to_replace,GraphFunction * callee,bool inline_optimized_graph)2170 std::vector<Value*> inlineCallTo(
2171     Node* to_replace,
2172     GraphFunction* callee,
2173     bool inline_optimized_graph /*=true*/) {
2174   auto graph =
2175       inline_optimized_graph ? callee->optimized_graph() : callee->graph();
2176   return inlineCallTo(to_replace, callee, graph.get());
2177 }
2178 
// If `outputs` is exactly one tuple-typed value, unpacks it and returns the
// element values; otherwise returns `outputs` unchanged.
std::vector<Value*> unpackOutputs(const std::vector<Value*>& outputs) {
  if (outputs.size() != 1 || outputs.at(0)->type()->kind() != TupleType::Kind) {
    return outputs;
  }

  Value* tup = outputs[0];
  auto unpacked = createTupleUnpack(tup);
  std::vector<Value*> new_outputs(unpacked.begin(), unpacked.end());
  // If createTupleUnpack took its peephole path, the TupleConstruct may now
  // be dead; destroy it here and avoid the need for a DCE pass.
  Node* tuple_node = tup->node();
  if (tuple_node->kind() == prim::TupleConstruct && !tuple_node->hasUses()) {
    tuple_node->destroy();
  }
  return new_outputs;
}
2196 
findAllNodes(at::ArrayRef<Block * > array,Symbol kind,bool recurse)2197 std::vector<Node*> findAllNodes(
2198     at::ArrayRef<Block*> array,
2199     Symbol kind,
2200     bool recurse) {
2201   std::vector<Node*> ret;
2202   for (auto block : array) {
2203     findAllNodes(*block, kind, recurse, ret);
2204   }
2205   return ret;
2206 }
2207 
findAllNodes(Block & block,Symbol kind,bool recurse)2208 std::vector<Node*> findAllNodes(Block& block, Symbol kind, bool recurse) {
2209   return findAllNodes({&block}, kind, recurse);
2210 }
2211 
findAllNodes(Graph & g,Symbol kind,bool recurse)2212 std::vector<Node*> findAllNodes(Graph& g, Symbol kind, bool recurse) {
2213   return findAllNodes(*g.block(), kind, recurse);
2214 }
2215 
insertGraph(Graph & g,Graph & callee,ArrayRef<Value * > inputs,std::unordered_map<Value *,Value * > & value_map)2216 std::vector<Value*> insertGraph(
2217     Graph& g,
2218     Graph& callee,
2219     ArrayRef<Value*> inputs,
2220     std::unordered_map<Value*, Value*>& value_map) {
2221   auto value_map_func = [&](Value* v) { return value_map.at(v); };
2222   AT_ASSERT(callee.inputs().size() == inputs.size());
2223   for (const auto i : c10::irange(inputs.size())) {
2224     value_map[callee.inputs()[i]] = inputs[i];
2225   }
2226   for (auto* node : callee.nodes()) {
2227     auto* new_node = g.insertNode(g.createClone(node, value_map_func));
2228     for (size_t i = 0; i < node->outputs().size(); ++i) {
2229       value_map[node->outputs()[i]] = new_node->outputs()[i];
2230     }
2231   }
2232 
2233   std::vector<Value*> outputs;
2234   for (auto* output : callee.outputs()) {
2235     outputs.push_back(value_map_func(output));
2236   }
2237 
2238   return outputs;
2239 }
2240 
insertGraph(Graph & g,Graph & callee,ArrayRef<Value * > inputs)2241 std::vector<Value*> insertGraph(
2242     Graph& g,
2243     Graph& callee,
2244     ArrayRef<Value*> inputs) {
2245   std::unordered_map<Value*, Value*> value_map;
2246   return insertGraph(g, callee, inputs, value_map);
2247 }
2248 
cloneFrom(Node * other_)2249 void ProfileOp::cloneFrom(Node* other_) {
2250   Node::cloneFrom(other_);
2251   auto other = other_->cast<ProfileOp>();
2252   this->callback_ = other->getCallback();
2253 }
2254 
// Allocates a fresh ProfileOp with an empty callback; cloneFrom() copies the
// real state afterwards (see Graph::createClone).
Node* ProfileOp::allocNewInstance(Graph* g) {
  return new ProfileOp(g, {nullptr});
}
2258 
cloneFrom(Node * other_)2259 void ProfileIValueOp::cloneFrom(Node* other_) {
2260   Node::cloneFrom(other_);
2261   auto other = other_->cast<ProfileIValueOp>();
2262   this->callback_ = other->getCallback();
2263 }
2264 
// Allocates a fresh ProfileIValueOp with an empty callback; cloneFrom()
// copies the real state afterwards (see Graph::createClone).
Node* ProfileIValueOp::allocNewInstance(Graph* g) {
  return new ProfileIValueOp(g, {nullptr});
}
2268 
type() const2269 TypePtr NamedValue::type() const {
2270   if (value_) {
2271     return value_->type();
2272   } else {
2273     return ivalue_.type();
2274   }
2275 }
2276 
// Node-kind symbols identifying the profiling ops in the IR.
const Symbol ProfileOp::Kind = ::c10::prim::profile;
const Symbol ProfileIValueOp::Kind = ::c10::prim::profile_ivalue;
2279 
// Builds the set from operator-schema literals (e.g. "aten::mul.Tensor(...)");
// delegates to insert().
OperatorSet::OperatorSet(std::initializer_list<const char*> sig_literals) {
  insert(sig_literals);
}
2283 
getOps() const2284 std::vector<std::shared_ptr<Operator>> OperatorSet::getOps() const {
2285   std::vector<std::shared_ptr<Operator>> result;
2286   for (const auto& kv : ops) {
2287     auto ops_for_symbol = kv.second;
2288     result.insert(result.end(), ops_for_symbol.begin(), ops_for_symbol.end());
2289   }
2290   return result;
2291 }
2292 
insert(std::initializer_list<const char * > sig_literals)2293 void OperatorSet::insert(std::initializer_list<const char*> sig_literals) {
2294   for (const char* sig : sig_literals) {
2295     auto op = getOperatorForLiteral(sig);
2296     ops[Symbol::fromQualString(op->schema().name())].push_back(op);
2297   }
2298 }
2299 
isMemberOf(const OperatorSet & os) const2300 bool Node::isMemberOf(const OperatorSet& os) const {
2301   auto it = os.ops.find(kind());
2302   if (it == os.ops.end()) {
2303     return false;
2304   }
2305   for (auto& op : it->second) {
2306     if (matches(op->schema())) {
2307       return true;
2308     }
2309   }
2310   return false;
2311 }
2312 
2313 } // namespace torch::jit
2314