xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/ir/ir.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/ir/attributes.h>
4 #include <torch/csrc/jit/ir/graph_node_list.h>
5 #include <torch/csrc/jit/ir/named_value.h>
6 #include <torch/csrc/jit/ir/scope.h>
7 #include <torch/csrc/jit/runtime/operator.h>
8 
9 #include <torch/csrc/Export.h>
10 #include <torch/csrc/utils/python_stub.h>
11 #include <torch/csrc/utils/schema_info.h>
12 
13 #include <ATen/Utils.h>
14 #include <ATen/core/Tensor.h>
15 #include <ATen/core/dynamic_type.h>
16 #include <ATen/core/enum_type.h>
17 #include <ATen/core/functional.h>
18 #include <ATen/core/interned_strings.h>
19 #include <ATen/core/ivalue.h>
20 #include <ATen/core/jit_type.h>
21 #include <c10/util/ArrayRef.h>
22 #include <c10/util/Exception.h>
23 #include <optional>
24 
25 #include <functional>
26 #include <iosfwd>
27 #include <unordered_set>
28 #include <vector>
29 
30 // Forward declare, the real meat is in python_ir.cpp
31 template <class T>
32 class THPPointer;
33 using THPObjectPtr = THPPointer<PyObject>;
34 using pyobj_list = std::vector<THPObjectPtr>;
35 
36 namespace torch::jit {
37 namespace utils {
38 TORCH_API std::string getNodesModuleHierarchy(const Node& n);
39 } // namespace utils
40 class AliasDb;
41 
42 using ::c10::Argument;
43 using ::c10::FunctionSchema;
44 using ::c10::Symbol;
45 
46 using ::c10::ivalue::Shared;
47 
48 using ::c10::IValue;
49 using ::c10::ivalue::Future;
50 
51 using ::c10::ivalue::ConstantString;
52 
53 #define C10_USING(T) using ::c10::T;
54 C10_FORALL_TYPES(C10_USING)
55 #undef C10_USING
56 
57 #define C10_USING(T) using ::c10::T##Ptr;
58 C10_FORALL_TYPES(C10_USING)
59 #undef C10_USING
60 
61 using ::c10::Type;
62 using ::c10::TypeEnv;
63 using ::c10::TypePtr;
64 
65 using ::c10::getTypePtr;
66 using ::c10::MatchTypeReturn;
67 using ::c10::TypeKind;
68 
69 using ::c10::fmap;
70 
71 namespace prim {
72 using namespace ::c10::prim;
73 }
74 namespace attr {
75 using namespace ::c10::attr;
76 }
77 namespace aten {
78 using namespace ::c10::aten;
79 }
80 namespace cuda {
81 #if !defined(USE_ROCM)
82 using namespace ::c10::cuda;
83 #endif
84 } // namespace cuda
85 
86 struct Function;
87 struct GraphFunction;
88 struct MatchedSchema;
89 
90 // A Graph represents one "function" of computation.
91 // It uses a simple ownership model where the graph owns all the nodes inside
92 // it. All references inside the graph are raw pointers. Destroying the Graph
93 // will invalidate any pointers to nodes in the graph.
94 struct Graph;
95 
96 // Node is the base class of the IR graph. It represents one computation
97 // and dependencies on a list of Values. The "prim-ops", so to speak.
98 struct Node;
99 
100 // A Value represents an input or output to node that is either a
101 // Tensor or an opaque Handle object, as determined by type().
102 struct Value;
103 
104 TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g);
105 TORCH_API std::ostream& operator<<(std::ostream& out, const Node& n);
106 
107 // A list of nodes, with inputs and outputs
108 struct Block;
109 
110 // Each use is represented by this type, see 'Node::uses()'
111 // 'user' is the consumer of the value, 'offset' is the index into
112 // 'user's input this where the producers will be found.
113 struct Use {
UseUse114   Use(Node* user, size_t offset) : user(user), offset(offset) {}
115   Node* user;
116   size_t offset;
117 
118   bool operator==(const Use& b) {
119     return user == b.user && offset == b.offset;
120   }
121 };
122 
123 // Note [User node does not uniquely identify use]
124 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
125 // A while back, we wrote some code manipulating uses that looked like this:
126 //
127 //    for (auto& use : used_val->uses_) {
128 //      if (use.user == this_node) {
129 //        use.offset += 1;
130 //        break;
131 //      }
132 //    }
133 //
134 // This code is trying to find a particular use (our node's use) to update it.
135 // However, it's wrong: there may be *multiple* uses of a value %x in a node,
136 // as might be the case in this IR:
137 //
138 //    %y = Add %x %x
139 //
140 // In this case, there are two uses of %x whose user is the node 'Add %x %x'.
141 // So, "use induced by this node" is not a well-formed concept.
142 //
143 // If you are looking for "use induced by an input", it's best to use
144 // findUseForInput() to get it.
145 
146 // the list types are intentionally simple, but we type-def
147 // them here so if we need to change them, refactoring will be easier
148 using node_list = std::vector<Node*>;
149 using value_list = std::vector<Value*>;
150 using use_list = std::vector<Use>;
151 template <typename T>
152 using ArrayRef = at::ArrayRef<T>;
153 using NodeKind = Symbol;
154 using topo_position_t = int64_t;
155 using ValueSet = std::unordered_set<const Value*>;
156 
157 struct OperatorSet;
158 template <typename T>
159 struct OperatorMap;
160 
// Wrap<T> gives Python a handle to a Node/Value/Block that can be invalidated
// safely. When the wrapped C++ object is destroyed, it calls clear(). clear()
// runs the optional callback and then nulls out `elem`.
// Like most of the graph API, this is not safe for concurrent access to the
// same graph from different threads.
template <typename T>
struct Wrap {
  explicit Wrap(T* p) : elem(p) {}

  // Invalidate the handle. If a callback is installed, it first receives the
  // still-valid element pointer. Afterwards `elem` is reset to nullptr.
  void clear() {
    if (clear_cb != nullptr) {
      clear_cb(elem);
    }
    elem = nullptr;
  }

  T* elem;
  void (*clear_cb)(void*) = nullptr;
};
177 
178 struct Value {
179   AT_DISALLOW_COPY_AND_ASSIGN(Value);
180   Value(Node* node_, size_t offset_);
181 
182  private:
183   friend struct Node;
184   friend struct Graph;
185   Node* node_;
186   size_t offset_;
187   size_t unique_ = 0; // unique id
188   use_list uses_;
189   std::string unique_name_;
190   TypePtr type_;
191   // a managing wrapper for Python to allow invalidation
192   std::shared_ptr<Wrap<Value>> wrap_;
193 
194  public:
195   Value* setType(TypePtr type);
196   TORCH_API void inferTypeFrom(const at::Tensor& output);
197   TORCH_API void inferTypeFrom(
198       const c10::intrusive_ptr<c10::ivalue::Object>& output);
typeValue199   const TypePtr& type() const {
200     AT_ASSERT(type_ != nullptr);
201     return type_;
202   }
requires_gradValue203   bool requires_grad() const {
204     return type()->requires_grad();
205   }
isCompleteTensorValue206   bool isCompleteTensor() const {
207     if (auto pt = type()->cast<TensorType>()) {
208       return pt->isComplete();
209     }
210     return false;
211   }
212   TORCH_API bool mustBeNone() const;
213   TORCH_API bool mustNotBeNone() const;
uniqueValue214   size_t unique() const {
215     return unique_;
216   }
hasDebugNameValue217   bool hasDebugName() const {
218     return !unique_name_.empty();
219   }
220   static bool isValidName(const std::string& name);
221   TORCH_API Value* setDebugName(const std::string& name);
debugNameValue222   std::string debugName() const {
223     if (hasDebugName()) {
224       return unique_name_;
225     }
226     return std::to_string(unique());
227   }
228   TORCH_API std::string debugNameBase() const;
nodeValue229   Node* node() {
230     return node_;
231   }
offsetValue232   size_t offset() const {
233     return offset_;
234   }
setOffsetValue235   void setOffset(size_t offset) {
236     offset_ = offset;
237   }
nodeValue238   const Node* node() const {
239     return node_;
240   }
241 
242   /**
243    * @warning NEVER pass raw pointer of smart pointer managed Graph to Python.
244    * Check #87343 for details.
245    */
246   Graph* owningGraph();
247   const Graph* owningGraph() const;
248   // TODO: make this more const correct
usesValue249   const use_list& uses() const {
250     return uses_;
251   }
252 
hasUsesValue253   bool hasUses() const {
254     return !uses().empty();
255   }
256 
257   TORCH_API void replaceFirstUseWith(Value* newValue);
258 
259   // Replaces all uses of this value with 'newValue'.
260   //
261   // Given:   %3 = f(%1, %2)
262   //          %4 = g(%3)
263   //          %5 = h(%3, %3)
264   // Execute: %3.replaceAllUsesWith(%6)
265   // Result:  %3 = f(%1, %2)
266   //          %4 = g(%6)
267   //          %5 = h(%6, %6)
268   TORCH_API void replaceAllUsesWith(Value* newValue);
269 
270   // Replaces all uses of this value with 'newValue' after 'node'.
271   // Given:   %3 = f(%1, %2)
272   //          %4 = g(%3)
273   //          %5 = inplace_(%3)
274   //          %6 = h(%3, %3)
275   // Execute: %3.replaceAllUsesAfterNodeWith(%5.node(), %5)
276   // Result:  %3 = f(%1, %2)
277   //          %4 = g(%3)
278   //          %5 = inplace_(%3)
279   //          %6 = h(%5, %5)
280   // XXX: does not check scoping legality, consider using
281   // replaceAllUsesDominatedByNodeWith
282   TORCH_API void replaceAllUsesAfterNodeWith(const Node* node, Value* newValue);
283 
284   // Replaces all uses of this value with 'newValue' that are dominated by
285   // 'node'. Given:
286   // x = op(...).
287   // if cond:
288   //    z = foo(..)
289   //    bar(x)
290   // else:
291   //    print(x)
292   // x.replaceAllUsesDominatedByNodeWith(foo, z) would replace bar(x)
293   // but not print(x) because print is not dominated by foo.
294   // replaceAllUsesAfterNode does not check domination, so in this example
295   // it would produce invalid IR.
296   TORCH_API void replaceAllUsesDominatedByNodeWith(
297       const Node* node,
298       Value* newValue);
299 
300   TORCH_API Value* copyMetadata(Value* from);
301 
wrapValue302   TORCH_API std::shared_ptr<Wrap<Value>> wrap() {
303     if (!wrap_) {
304       wrap_ = std::make_shared<Wrap<Value>>(this);
305     }
306     return wrap_;
307   }
308 
~ValueValue309   virtual ~Value() {
310     if (wrap_) {
311       wrap_->clear();
312     }
313   }
314 };
315 
316 struct TORCH_API Node {
317   AT_DISALLOW_COPY_AND_ASSIGN(Node);
318   friend struct Graph;
319   friend struct Block;
320   friend struct Value;
321   friend graph_node_list;
322   friend const_graph_node_list;
323   friend graph_node_list_iterator;
324   friend const_graph_node_list_iterator;
325 
326  private:
327   const NodeKind kind_;
328   std::vector<Value*> inputs_;
329   std::vector<Value*> outputs_;
330   // subblocks
331   std::vector<Block*> blocks_;
332   Graph* graph_;
333   Block* owning_block_;
334   std::optional<SourceRange> source_range_;
335   ScopePtr scope_;
336   std::optional<InlinedCallStackPtr> callstack_;
337   // Assumes FunctionSchemas are persistent, so we don't manage their lifetime.
338   // This field is effective a cache that's populated on attribute lookups and
339   // invalidated every time we perform an operation that could potentially
340   // change the schema. note: mutable because schema_ is effectively a cache
341   mutable const Operator* op_;
342   topo_position_t topo_position_ = 0;
343   // a managing wrapper for Python to allow invalidation
344   std::shared_ptr<Wrap<Node>> wrap_;
345   // Stores the full schema name, if the operator is historic
346   // When the operator is deprecated or the name of the operator
347   // is changed, we need to rely on this name
348   // to retrieve old schemas to successfully apply upgraders
349   // for this operator.
350   std::optional<std::string> historic_schema_name_ = std::nullopt;
351 
352  protected:
353   Node(Graph* graph_, NodeKind kind_); // defined after graph
354  public:
355   // Each Node but Return/Param Nodes are associated with exactly one
356   // place in the Node list of the Graph. The Graph itself is a circular
357   // doubly-linked list. The Return Node is used as the sentinel for the
358   // "beginning"/"end" of the list. This means that you can tell when
359   // you've traversed the entire list without means worrying about null
360   // pointers. `next_in_graph[0]` is the pointer to the next Node, while
361   // `next_in_graph[1]` is the pointer to the previous Node. The
362   // linked list is implemented as an array to allow the same iterator
363   // class for forward and reversed Node lists. Taken together, this
364   // list also represents a topological sort of the Nodes in the Graph.
365   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-non-private-member-variables-in-classes,modernize-avoid-c-arrays)
366   Node* next_in_graph[2] = {nullptr, nullptr};
367 
wrapNode368   std::shared_ptr<Wrap<Node>> wrap() {
369     if (!wrap_) {
370       wrap_ = std::make_shared<Wrap<Node>>(this);
371     }
372     return wrap_;
373   }
374 
getHistoricSchemaNameNode375   const std::optional<std::string> getHistoricSchemaName() {
376     return historic_schema_name_;
377   }
378 
setHistoricSchemaNameNode379   void setHistoricSchemaName(const std::string& name) {
380     historic_schema_name_ = name;
381   }
382 
nextNode383   Node*& next() {
384     return next_in_graph[kNextDirection];
385   }
prevNode386   Node*& prev() {
387     return next_in_graph[kPrevDirection];
388   }
nextNode389   Node* const& next() const {
390     return next_in_graph[kNextDirection];
391   }
prevNode392   Node* const& prev() const {
393     return next_in_graph[kPrevDirection];
394   }
395 
kindNode396   NodeKind kind() const {
397     return kind_;
398   }
setSourceRangeNode399   Node* setSourceRange(SourceRange r) {
400     source_range_ = std::move(r);
401     return this;
402   }
403   SourceRange sourceRange() const;
404 
405   /**
406    * @warning NEVER pass raw pointer of smart pointer managed Graph to Python.
407    * Check #87343 for details.
408    */
owningGraphNode409   Graph* owningGraph() {
410     return graph_;
411   }
owningGraphNode412   const Graph* owningGraph() const {
413     return graph_;
414   }
owningBlockNode415   Block* owningBlock() {
416     return owning_block_;
417   }
owningBlockNode418   const Block* owningBlock() const {
419     return owning_block_;
420   }
scopeNode421   ScopePtr scope() {
422     return scope_;
423   }
setScopeNode424   void setScope(ScopePtr scope) {
425     scope_ = std::move(scope);
426   }
scopeNameNode427   std::string scopeName() const {
428     if (!scope_) {
429       return "";
430     }
431     return scope_->namesFromRoot();
432   }
433 
434   // Copies the source range, scope and callstack from another node.
copyMetadataNode435   Node* copyMetadata(Node* from) {
436     this->setSourceRange(from->sourceRange());
437     this->setScope(from->scope());
438     if (auto cs = from->callstack()) {
439       this->setCallStack(*cs);
440     }
441     return this;
442   }
443 
callstackNode444   std::optional<InlinedCallStackPtr> callstack() const {
445     return callstack_;
446   }
setCallStackNode447   void setCallStack(InlinedCallStackPtr cs) {
448     callstack_ = std::move(cs);
449   }
450 
451   // NB: This returns an ArrayRef; that means that it will
452   // get invalidated if you resize inputs (e.g., using addInput)
453   // We can't return a std::vector<Node*>& because there's no
454   // way to soundly cast to std::vector<const Node*> (an insane
455   // implementation of std::vector could make this representationally
456   // different.)
inputsNode457   at::ArrayRef<Value*> inputs() {
458     return inputs_;
459   }
inputsNode460   at::ArrayRef<const Value*> inputs() const {
461     // Vectors are not convertible in const-ness of elements, but
462     // raw pointers are.
463     return {inputs_.data(), inputs_.size()};
464   }
465   // NB: This returns an ArrayRef; that means that it will
466   // get invalidated if you resize inputs (e.g., using addInput)
467   // We can't return a std::vector<Node*>& because there's no
468   // way to soundly cast to std::vector<const Node*> (an insane
469   // implementation of std::vector could make this representationally
470   // different.)
outputsNode471   at::ArrayRef<Value*> outputs() {
472     return outputs_;
473   }
outputsNode474   at::ArrayRef<const Value*> outputs() const {
475     // Vectors are not convertible in const-ness of elements, but
476     // raw pointers are.
477     return {outputs_.data(), outputs_.size()};
478   }
outputNode479   Value* output(size_t i) const {
480     return outputs_.at(i);
481   }
hasUsesNode482   bool hasUses() const {
483     for (auto o : outputs()) {
484       if (!o->uses().empty()) {
485         return true;
486       }
487     }
488     return false;
489   }
490 
491   void replaceAllUsesWith(Node* n);
492 
493   // replaces `this` with a new node with the same inputs and outputs
494   // but a new node symbol. does not destroy `this`
495   Node* replaceWithNewSymbol(Symbol new_symbol);
496 
497   // Checks if this node is dominated by `dominator` which means that
498   // `dominator` will always be executed before `this` and `dominator`
499   // is in scope of `this.
500   bool isDominatedBy(const Node* dominator) const;
501 
502   // lots of things like chunk have a single input or single output, so we have
503   // a helper to make accessing it easier
inputNode504   Value* input() {
505     AT_ASSERT(inputs_.size() == 1);
506     return inputs_.at(0);
507   }
outputNode508   Value* output() {
509     AT_ASSERT(outputs_.size() == 1);
510     return outputs_.at(0);
511   }
outputNode512   const Value* output() const {
513     AT_ASSERT(outputs_.size() == 1);
514     return outputs_.at(0);
515   }
inputNode516   const Value* input() const {
517     AT_ASSERT(inputs_.size() == 1);
518     return inputs_.at(0);
519   }
520   // Access a particular input.  This is a checked index.
inputNode521   Value* input(size_t i) const {
522     return inputs_.at(i);
523   }
524 
525   bool hasNamedInput(const std::string& unqualName) const;
526   Value* namedInput(const std::string& unqualName) const;
527   Value* namedInput(Symbol name) const;
528 
529   std::optional<IValue> get(Symbol name) const;
530 
531   template <typename T>
getNode532   std::optional<T> get(Symbol name) const {
533     if (auto v = get(name)) {
534       return v->template to<T>();
535     }
536     return std::nullopt;
537   }
538 
539   // Returns true if the value of input name is statically known
is_constantNode540   bool is_constant(Symbol name) const {
541     return static_cast<bool>(get(name));
542   }
543   bool mustBeNone() const;
544 
545   bool isNondeterministic() const;
546   bool hasSideEffects() const;
547 
548   // instructions lowered by the interpreter and not run in the optimized graph
notExecutedOpNode549   bool notExecutedOp() const {
550     return kind_ == prim::Constant || kind_ == prim::profile ||
551         kind_ == prim::profile_ivalue;
552   }
553 
554   // Graphs
555 
556   // Note [Topological invariant]
557   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
558   // We always maintain an up-to-date topological ordering of all nodes via
559   // the next()/prev() links.  All transformations to graphs must preserve
560   // this topological ordering: for example, it is only valid to 'addInput'
561   // with an input which is topologically before the current node.
562   //
563   // Usually, it is obvious whether or not topological order is maintained;
564   // for example, if you are adding nodes to the end of the topsort, it's
565   // impossible for them to refer to inputs that are not in the topsort.
566   // If it is not obvious, please comment accordingly.
567 
568   // Add 'node' as an input to 'this' at the end of existing
569   // arguments.  Returns the added node for ease of chaining.
570   //
571   // Given:   %3 = f(%1, %2)
572   // Execute: %3.addInput(%4)
573   // Result:  %3 = f(%1, %2, %4)
574   Value* addInput(Value* value);
575 
576   // Add 'value' as an input to 'this' at the specified position in the
577   // arguments. Returns the added value for ease of chaining.
578   Value* insertInput(size_t i, Value* value);
579 
580   // Replace the input of 'this' at position 'i' with
581   // 'newValue', returning the old node.
582   //
583   // Given:   %3 = f(%1, %2)
584   // Execute: %3.replaceInput(1, %4)
585   // Result:  %3 = f(%1, %4)
586   Value* replaceInput(size_t i, Value* newValue);
587 
588   // Replace all occurrences of 'from' in the inputs of this
589   // node with 'to'. Corresponds to llvm's replaceUsesOfWith.
590   //
591   // Given:   %3 = f(%1, %2, %1)
592   // Execute: %3.replaceInputWith(%1, %4)
593   // Result:  %3 = f(%4, %2, %4)
594   void replaceInputWith(Value* from, Value* to);
595 
596   Value* addOutput();
597 
598   Value* insertOutput(size_t i);
599 
600   void eraseOutput(size_t i);
601 
602   Block* addBlock();
603   void eraseBlock(size_t i);
604 
605   // Each Node can have a list of subblocks. These are used to define structured
606   // nested control flow operators such as If and Loop.
607   // The meaning of a block is specific to the kind of node it is in, but
608   // all blocks share these semantics:
609   // * Nested lexical scoping: If a node 'Parent' has a subblock which contains
610   //   a node 'Child', Child can use any value that was in scope for the Parent
611   //   node in addition to any values defined before 'Child' in the subblock.
612   // * The list of inputs to the block are in scope for the duration of the
613   //   block
614   // * the outputs of the Parent node are not in scope for the subblocks
615   // Typically the inputs to a block that represents control flow act as
616   // as the equivalents phi-nodes in standard SSA form,
617   // defining a new Value to represent any term that has multiple
618   // definitions depending on how control flowed. Outputs of the node containing
619   // control flow serve a similiar purpose defining new values for variables
620   // that would have different definitions depending on which way control
621   // flowed.
622 
blocksNode623   at::ArrayRef<Block*> blocks() {
624     return blocks_;
625   }
blocksNode626   at::ArrayRef<const Block*> blocks() const {
627     // Vectors are not convertible in const-ness of elements, but
628     // raw pointers are.
629     return {blocks_.data(), blocks_.size()};
630   }
631 
632   // Is 'this' before 'n' in the topological order?
633   bool isBefore(const Node* n) const;
634 
635   // Is 'this' after 'n' in the topological order?
636   bool isAfter(const Node* n) const;
637 
638   // Insert unattached 'this' node before 'n' in the topological order.
639   // Returns this (for chaining).
640   //
641   // Given:   %3 = f(%1, %2)
642   //          %4 = g(%3)
643   // and unattached: %5 = h(%1)
644   // Execute: %5.insertBefore(%4)
645   // Result:  %3 = f(%1, %2)
646   //          %5 = h(%1)
647   //          %4 = g(%3)
648   Node* insertBefore(Node* n);
649 
650   // Insert unattached 'this' node after 'n' in the topological order.
651   // Returns this (for chaining).
652   //
653   // Given: %3 = f(%1, %2)
654   //        %4 = g(%3)
655   // and unattached: %5 = h(%1)
656   // Execute: %5.insertAfter(%4)
657   // Result:  %3 = f(%1, %2)
658   //          %4 = g(%3)
659   //          %5 = h(%1)
660   Node* insertAfter(Node* n);
661 
662   // Move 'this' (already in the graph) after 'n' in the topological order.
663   //
664   // NOTE: Does not check that value dependencies are preserved, see
665   //   AliasDb::moveAfterTopologicallyValid
666   //
667   // Given: %2 = f(%1)
668   //        %3 = g(%1)
669   // Execute: %2.moveAfter(%3)
670   // Result: %3 = g(%1)
671   //         %2 = f(%1)
672   //
673   void moveAfter(Node* n);
674 
675   // Move a node 'n' (already in the graph) before 'this' in the topological
676   // order.
677   //
678   // NOTE: Does not check that value dependencies are preserved, see
679   //   AliasDb::moveBeforeTopologicallyValid
680   //
681   // Given: %2 = f(%1)
682   //        %3 = g(%1)
683   // Execute: %3.moveBefore(%2)
684   // Result: %3 = g(%1)
685   //         %2 = f(%1)
686   void moveBefore(Node* n);
687 
688   // Remove the input at 'i' from this node.
689   //
690   // WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling
691   // removeInput.
692   //
693   // Given: %3 = f(%1, %2)
694   // Execute: %3.removeInput(1)
695   // Result: %3 = f(%1)
696   void removeInput(size_t i);
697 
698   // Remove all inputs from a node.
699   //
700   // Given: %3 = f(%1, %2)
701   // Execute: %3.removeAllInputs()
702   // Result: %3 = f()
703   void removeAllInputs();
704 
705   // Remove all outputs from a node.
706   //
707   // Given: %1, %2 = f()
708   // Execute:removeAllInputs()
709   // Result: = f()
710   void removeAllOutputs();
711 
712   // Rearrange the ordering of inputs or outputs of a node
713   // Given: %3 = f(%1, %2)
714   // Execute: %3.permuteInputs({1, 0})
715   // Result: %3 = f(%2, %1)
716   // Each index must appear exactly once
717   void permuteInputs(const std::vector<size_t>& new_inputs);
718   void permuteOutputs(const std::vector<size_t>& new_inputs);
719 
720   // iterators of the node list starting at this node
721   // useful for resuming a search starting at this node
iteratorNode722   inline graph_node_list_iterator iterator() {
723     return {this, 0};
724   }
reverseIteratorNode725   inline graph_node_list_iterator reverseIterator() {
726     return iterator().reverse();
727   }
iteratorNode728   inline const_graph_node_list_iterator iterator() const {
729     return {this, 0};
730   }
reverseIteratorNode731   inline const_graph_node_list_iterator reverseIterator() const {
732     return iterator().reverse();
733   }
734 
735   // Remove 'this' from the instruction list and deallocate it.
736   //
737   // Invariant: no outputs of 'this' may have any uses.
738   //
739   // Given: %2 = f(%1)
740   //        %3 = g(%1)
741   // Execute: %2.destroy()
742   // Result: %3 = g(%1)
743   void destroy();
744 
745   // Dynamically cast this node to the subclass indicated by the
746   // template variable, returning nullptr if the cast is invalid..
747   //
748   // Example usage: if(auto s = n.cast<Select>()) { ... }
749   template <typename T>
castNode750   T* cast() {
751     if (T::Kind == kind()) {
752       return static_cast<T*>(this);
753     }
754     return nullptr;
755   }
756   template <typename T>
castNode757   const T* cast() const {
758     if (T::Kind == kind()) {
759       return static_cast<const T*>(this);
760     }
761     return nullptr;
762   }
763 
764   template <typename T>
expectNode765   T* expect() {
766     TORCH_CHECK(
767         T::Kind == kind(),
768         "expected a ",
769         T::Kind.toDisplayString(),
770         " but found a ",
771         kind().toDisplayString());
772     return static_cast<T*>(this);
773   }
774 
775   bool matches(const FunctionSchema& schema) const;
776 
777   // XXX: this function is meant to be used with string literals only!
778   bool matches(
779       const char* signature_literal,
780       at::ArrayRef<Symbol> const_inputs = {}) const;
781 
782   bool isMemberOf(const OperatorSet& os) const;
783   template <typename T>
isMemberOfNode784   bool isMemberOf(const OperatorMap<T>& om) const {
785     auto it = om.map.find(kind());
786     if (it == om.map.end()) {
787       return false;
788     }
789     for (auto& op : it->second) {
790       if (matches(op.first->schema())) {
791         return true;
792       }
793     }
794     return false;
795   }
796 
797   const FunctionSchema& schema() const;
798   const FunctionSchema* maybeSchema() const;
799   const Operator& getOperator() const;
800   Operation getOperation() const;
801 
802   const Operator* maybeOperator() const;
803 
804   void dump() const;
805 
806   std::ostream& print(
807       std::ostream& out,
808       size_t level,
809       std::vector<const Node*>* groups,
810       bool print_source_locations = true,
811       bool print_attributes = true,
812       bool print_scopes = true,
813       bool print_body = true) const;
814 
~NodeNode815   virtual ~Node() {
816     if (wrap_) {
817       wrap_->clear();
818     }
819   }
820 
821   // Methods for accessing attributes
copyAttributesNode822   Node* copyAttributes(const Node& rhs) {
823     values_.clear();
824     for (const AVPtr& i : rhs.values_) {
825       values_.push_back(i->clone());
826     }
827     return this;
828   }
hasAttributeNode829   bool hasAttribute(Symbol name) const {
830     AT_ASSERT(name.is_attr());
831     return findAttr(name, false) != values_.end();
832   }
hasAttributeSNode833   bool hasAttributeS(const std::string& name) const {
834     return hasAttribute(Symbol::attr(name));
835   }
kindOfNode836   AttributeKind kindOf(Symbol name) const {
837     AT_ASSERT(name.is_attr());
838     return (*findAttr(name, true))->kind();
839   }
kindOfSNode840   AttributeKind kindOfS(const std::string& name) const {
841     return kindOf(Symbol::attr(name));
842   }
removeAttributeNode843   Node* removeAttribute(Symbol name) {
844     AT_ASSERT(name.is_attr());
845     values_.erase(findAttr(name, true));
846     return this;
847   }
removeAttributeSNode848   Node* removeAttributeS(const std::string& name) {
849     return removeAttribute(Symbol::attr(name));
850   }
hasAttributesNode851   bool hasAttributes() const {
852     return !values_.empty();
853   }
numAttributesNode854   size_t numAttributes() const {
855     return values_.size();
856   }
857   // The names are returned in order, since name actually is the index.
attributeNamesNode858   std::vector<Symbol> attributeNames() const {
859     std::vector<Symbol> names;
860     names.reserve(values_.size());
861     for (const AVPtr& a : values_) {
862       names.push_back(a->name);
863     }
864     return names;
865   }
attributeNamesSNode866   std::vector<const char*> attributeNamesS() const {
867     std::vector<const char*> names;
868     names.reserve(values_.size());
869     for (const AVPtr& a : values_) {
870       names.push_back(a->name.toUnqualString());
871     }
872     return names;
873   }
874 
875 #define CREATE_ACCESSOR(Kind, method)                           \
876   Node* method##_(Symbol name, Kind##Attr::ConstructorType v) { \
877     return setAttr<Kind##Attr>(                                 \
878         name, std::forward<Kind##Attr::ConstructorType>(v));    \
879   }                                                             \
880   const Kind##Attr::ValueType& method(Symbol name) const {      \
881     return getAttr<Kind##Attr>(name);                           \
882   }
883 
CREATE_ACCESSORNode884   CREATE_ACCESSOR(Float, f)
885   CREATE_ACCESSOR(Complex, c)
886   CREATE_ACCESSOR(Floats, fs)
887   CREATE_ACCESSOR(ComplexVals, cs)
888   CREATE_ACCESSOR(String, s)
889   CREATE_ACCESSOR(Strings, ss)
890   CREATE_ACCESSOR(Int, i)
891   CREATE_ACCESSOR(Ints, is)
892   CREATE_ACCESSOR(Graph, g)
893   CREATE_ACCESSOR(Graphs, gs)
894   CREATE_ACCESSOR(Type, ty)
895   CREATE_ACCESSOR(Types, tys)
896   CREATE_ACCESSOR(IValue, ival)
897 
898 #undef CREATE_ACCESSOR
899 
900   // Our Graphs are not very const-correct, so we need to allow returning
901   // non-const references too
902   GraphAttr::ValueType& g(Symbol name) {
903     return getAttr<GraphAttr>(name);
904   }
905 
906   // does not use CREATE_ACCESSOR because we need additional asserts
t_Node907   Node* t_(Symbol name, TensorAttr::ConstructorType v) {
908     return setAttr<TensorAttr>(
909         name, std::forward<TensorAttr::ConstructorType>(v));
910   }
tNode911   const TensorAttr::ValueType& t(Symbol name) const {
912     return getAttr<TensorAttr>(name);
913   }
914 
ts_Node915   Node* ts_(Symbol name, TensorsAttr::ConstructorType v) {
916     return setAttr<TensorsAttr>(
917         name, std::forward<TensorsAttr::ConstructorType>(v));
918   }
tsNode919   const TensorsAttr::ValueType& ts(Symbol name) const {
920     return getAttr<TensorsAttr>(name);
921   }
922 
923   Block* findCommonAncestorBlockWith(Node* n);
924 
925   size_t blocksFromGraphBlock();
926 
927  private:
928   void printAttrValue(std::ostream& out, const Symbol& name) const;
929   void printAttributes(std::ostream& out, bool ignore_subgraph) const;
930 
931   template <typename T>
setAttrNode932   Node* setAttr(Symbol name, typename T::ConstructorType v) {
933     AT_ASSERT(name.is_attr());
934     auto it = findAttr(name, false);
935     auto nv = AVPtr(new T(name, std::forward<typename T::ConstructorType>(v)));
936     // NOLINTNEXTLINE(bugprone-branch-clone)
937     if (it == values_.end()) {
938       values_.push_back(std::move(nv));
939     } else {
940       *it = std::move(nv);
941     }
942     return this;
943   }
944   template <typename T>
getAttrNode945   typename T::ValueType& getAttr(Symbol name) const {
946     AT_ASSERT(name.is_attr());
947     auto it = findAttr(name, true);
948     auto* child = dynamic_cast<T*>(it->get());
949     if (child == nullptr) {
950       throw IRAttributeError(name, true);
951     }
952     return child->value();
953   }
954   using AVPtr = AttributeValue::Ptr;
955   // NB: For determinism, we use a vector rather than a hash map.  This does
956   // mean that lookups are O(n), so you shouldn't use Attributes to store
957   // a big pile of messages.
958   std::vector<AVPtr> values_;
findAttrNode959   std::vector<AVPtr>::iterator findAttr(Symbol name, bool required) {
960     AT_ASSERT(name.is_attr());
961     auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) {
962       return v->name == name;
963     });
964     if (required && it == values_.end()) {
965       throw IRAttributeError(name, false);
966     }
967     AT_ASSERT(!required || it != values_.end());
968     return it;
969   }
findAttrNode970   std::vector<AVPtr>::const_iterator findAttr(Symbol name, bool required)
971       const {
972     AT_ASSERT(name.is_attr());
973     auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) {
974       return v->name == name;
975     });
976     if (required && it == values_.end()) {
977       throw IRAttributeError(name, false);
978     }
979     AT_ASSERT(!required || it != values_.end());
980     return it;
981   }
982 
983   enum class MoveSide { BEFORE, AFTER };
984   bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
985 
986   std::pair<Value*, const Argument&> findInput(Symbol name);
987   // Lookup iterator in use list of _input i_ that corresponds to its use of
988   // _this_
989   use_list::iterator findUseForInput(size_t i);
990 
991   // remove the use of input i, this sets input i to nullptr, but
992   // is only used internally to Node before setting it to a new value
993   // or erasing the entry from the list.
994   Value* dropInput(size_t i);
995 
inBlockListNode996   bool inBlockList() const {
997     if (next() == nullptr) {
998       AT_ASSERT(prev() == nullptr);
999     }
1000     return next() != nullptr;
1001   }
1002 
1003   void removeFromList();
1004   void lint() const;
1005 
1006   void assignTopoPosition();
1007 
1008  protected:
1009   // subclasses must override
1010   // this function is used by createClone to initialize a new version
1011   // of a node in another graph. It should allocate a new instance of the same
1012   // concrete type as 'this', but in graph 'g' which might be different
1013   // than graph_
allocNewInstanceNode1014   virtual Node* allocNewInstance(Graph* g) {
1015     return new Node(g, kind());
1016   }
1017   // create a copy of all properties of Node s into this.
1018   // subclasses should extend if they have additional information to copy.
1019   // 'this' will be allocated with s->allocNewInstance(g) so it should have
1020   // the same concrete type as 's'
1021   virtual void cloneFrom(Node* s);
1022 };
1023 
1024 struct Block {
1025   friend struct Node;
1026   friend struct Graph;
1027 
1028   AT_DISALLOW_COPY_AND_ASSIGN(Block);
1029   TORCH_API Block(Graph* graph_, Node* node_);
1030 
inputsBlock1031   at::ArrayRef<Value*> inputs() {
1032     return input_->outputs();
1033   }
inputsBlock1034   at::ArrayRef<const Value*> inputs() const {
1035     const auto& inputs = input_->outputs();
1036     return {inputs.data(), inputs.size()};
1037   }
outputsBlock1038   at::ArrayRef<Value*> outputs() {
1039     return output_->inputs();
1040   }
outputsBlock1041   at::ArrayRef<const Value*> outputs() const {
1042     return static_cast<const Node*>(output_)->inputs();
1043   }
nodesBlock1044   graph_node_list nodes() {
1045     return {input_, kNextDirection};
1046   }
nodesBlock1047   const_graph_node_list nodes() const {
1048     return {input_, kNextDirection};
1049   }
return_nodeBlock1050   Node* return_node() {
1051     return output_;
1052   }
return_nodeBlock1053   const Node* return_node() const {
1054     return output_;
1055   }
param_nodeBlock1056   Node* param_node() {
1057     return input_;
1058   }
param_nodeBlock1059   const Node* param_node() const {
1060     return input_;
1061   }
1062   /**
1063    * @warning NEVER pass raw pointer of smart pointer managed Graph to Python.
1064    * Check #87343 for details.
1065    */
owningGraphBlock1066   Graph* owningGraph() {
1067     return graph_;
1068   }
owningGraphBlock1069   const Graph* owningGraph() const {
1070     return graph_;
1071   }
owningNodeBlock1072   Node* owningNode() {
1073     return owning_node_;
1074   }
owningNodeBlock1075   const Node* owningNode() const {
1076     return owning_node_;
1077   }
1078 
1079   Value* addInput(const std::string& name = "") {
1080     Value* v = input_->addOutput();
1081     v->setDebugName(name);
1082     return v;
1083   }
1084   Value* insertInput(size_t i, const std::string& name = "") {
1085     Value* v = input_->insertOutput(i);
1086     v->setDebugName(name);
1087     return v;
1088   }
eraseInputBlock1089   void eraseInput(size_t i) {
1090     input_->eraseOutput(i);
1091   }
removeAllInputsBlock1092   void removeAllInputs() {
1093     input_->removeAllOutputs();
1094   }
registerOutputBlock1095   size_t registerOutput(Value* v) {
1096     output_->addInput(v);
1097     return outputs().size() - 1;
1098   }
insertOutputBlock1099   size_t insertOutput(size_t i, Value* n) {
1100     output_->insertInput(i, n);
1101     return i;
1102   }
eraseOutputBlock1103   void eraseOutput(size_t i) {
1104     output_->removeInput(i);
1105   }
removeAllOutputsBlock1106   void removeAllOutputs() {
1107     output_->removeAllInputs();
1108   }
1109 
replaceOutputBlock1110   void replaceOutput(size_t i, Value* n) {
1111     output_->replaceInput(i, n);
1112   }
permuteOutputsBlock1113   void permuteOutputs(const std::vector<size_t>& new_inputs) {
1114     output_->permuteInputs(new_inputs);
1115   }
permuteInputsBlock1116   void permuteInputs(const std::vector<size_t>& new_inputs) {
1117     input_->permuteOutputs(new_inputs);
1118   }
1119 
appendNodeBlock1120   Node* appendNode(Node* n) {
1121     AT_ASSERT(n->graph_ == graph_ && !n->inBlockList());
1122     n->insertBefore(output_);
1123     return n;
1124   }
prependNodeBlock1125   Node* prependNode(Node* n) {
1126     AT_ASSERT(n->graph_ == graph_ && !n->inBlockList());
1127     n->insertAfter(input_);
1128     return n;
1129   }
1130 
1131   // clone all inputs, nodes, and outputs from src and append them
1132   // to the inputs, nodes, and outputs of this block
1133   // value_map is used whenever a node in src references a free variable
1134   // in src to look up its corresponding value
1135   TORCH_API void cloneFrom(Block* src, std::function<Value*(Value*)> value_map);
1136   TORCH_API void remapTypes(const std::function<TypePtr(TypePtr)>& type_map);
1137 
wrapBlock1138   TORCH_API std::shared_ptr<Wrap<Block>> wrap() {
1139     if (!wrap_) {
1140       wrap_ = std::make_shared<Wrap<Block>>(this);
1141     }
1142     return wrap_;
1143   }
1144 
~BlockBlock1145   virtual ~Block() {
1146     if (wrap_) {
1147       wrap_->clear();
1148     }
1149   }
1150 
clearBlock1151   void clear() {
1152     removeAllOutputs();
1153     for (auto it = nodes().rbegin(); it != nodes().rend(); it++) {
1154       it.destroyCurrent();
1155     }
1156     removeAllInputs();
1157   }
1158 
1159  private:
1160   void reIndexTopology();
1161 
1162   // get rid of all nodes
1163   // destroys in reverse order so that uses internal to this block
1164   // do not have to be removed before you can destroy the block
1165   void destroy();
1166 
1167   Graph* const graph_;
1168   // holds outputs in a way that can be reflected
1169   // as a Use object
1170   // also used as the beginning/end of the circular node list to avoid
1171   // having corner cases where the list is empty.
1172   Node* const output_;
1173   Node* const input_;
1174   Node* const
1175       owning_node_; // either the node that has this block or nullptr for root
1176   // a managing wrapper for Python to allow invalidation
1177   std::shared_ptr<Wrap<Block>> wrap_;
1178 };
1179 
1180 struct Graph : std::enable_shared_from_this<Graph> {
1181   AT_DISALLOW_COPY_AND_ASSIGN(Graph);
1182   friend struct Node;
1183   friend struct Value;
1184   friend struct Block;
1185 
1186  private:
1187   // only used to keep track of allocated nodes
1188   // actual representation of Graph is done with
1189   // inputs, outputs, nodes
1190 
1191   std::unordered_set<const Node*> all_nodes;
1192   std::unordered_set<const Value*> all_values;
1193   std::unordered_set<const Block*> all_blocks;
1194   size_t next_unique_{0};
1195 
1196   std::unordered_map<std::string, Value*> unique_names_;
1197   // name_base_suffix tracks largest suffix currently used by all names sharing
1198   // same name_base. Key of this map is name_base, value is largest suffix
1199   // numeric value.
1200   std::unordered_map<std::string, size_t> name_base_suffix_;
1201 
1202   ScopePtr current_scope_;
1203 
1204   Block* const block_;
1205   // when insertNode() is called, the node is inserted before this node
1206   // by default this is set to append to the top level block
1207   Node* insert_before_;
1208   int64_t predicted_insert_count_ = 0;
1209 
1210   std::optional<size_t> op_version_;
1211 
1212  public:
1213   Graph(ScopePtr scope_root = c10::make_intrusive<Scope>())
current_scope_Graph1214       : current_scope_(std::move(scope_root)),
1215         block_(new Block(this, nullptr)),
1216         insert_before_(return_node()) {}
1217 
inputsGraph1218   at::ArrayRef<Value*> inputs() {
1219     return block_->inputs();
1220   }
inputsGraph1221   at::ArrayRef<const Value*> inputs() const {
1222     const Block& block = *block_;
1223     return block.inputs();
1224   }
outputsGraph1225   at::ArrayRef<Value*> outputs() {
1226     return block_->outputs();
1227   }
outputsGraph1228   at::ArrayRef<const Value*> outputs() const {
1229     const Block& block = *block_;
1230     return block.outputs();
1231   }
nodesGraph1232   graph_node_list nodes() {
1233     return block_->nodes();
1234   }
nodesGraph1235   const_graph_node_list nodes() const {
1236     const Block& block = *block_;
1237     return block.nodes();
1238   }
param_nodeGraph1239   Node* param_node() {
1240     return block_->param_node();
1241   }
param_nodeGraph1242   const Node* param_node() const {
1243     return block_->param_node();
1244   }
return_nodeGraph1245   Node* return_node() {
1246     return block_->return_node();
1247   }
return_nodeGraph1248   const Node* return_node() const {
1249     return block_->return_node();
1250   }
debugNamesGraph1251   const std::unordered_map<std::string, Value*>& debugNames() const {
1252     return unique_names_;
1253   }
1254 
1255   TORCH_API void push_scope(const std::string& scope_name);
1256   TORCH_API void pop_scope();
1257 
current_scopeGraph1258   ScopePtr current_scope() {
1259     return current_scope_;
1260   }
1261 
set_op_versionGraph1262   void set_op_version(std::optional<size_t> version) {
1263     op_version_ = version;
1264   }
1265 
get_op_versionGraph1266   std::optional<size_t> get_op_version() {
1267     return op_version_;
1268   }
1269 
set_current_scopeGraph1270   void set_current_scope(ScopePtr scope) {
1271     current_scope_ = std::move(scope);
1272   }
1273 
1274   Value* addInput(const std::string& name = "") {
1275     return block_->addInput(name);
1276   }
1277   Value* insertInput(size_t i, const std::string& name = "") {
1278     return block_->insertInput(i, name);
1279   }
eraseInputGraph1280   void eraseInput(size_t i) {
1281     block_->eraseInput(i);
1282   }
registerOutputGraph1283   size_t registerOutput(Value* n) {
1284     return block_->registerOutput(n);
1285   }
eraseOutputGraph1286   void eraseOutput(size_t i) {
1287     block_->eraseOutput(i);
1288   }
1289 
1290   TORCH_API Node* create(NodeKind kind, size_t num_outputs = 1);
1291   TORCH_API Node* create(
1292       NodeKind kind,
1293       ArrayRef<Value*> inputs,
1294       size_t num_outputs = 1);
1295 
1296   TORCH_API Node* createNone();
1297   TORCH_API Node* createAutogradZero();
1298   TORCH_API Node* createUninitialized(TypePtr typ);
1299   TORCH_API Node* createWithSubgraph(Symbol kind);
1300   TORCH_API Node* createDifferentiableSubgraph();
1301   TORCH_API Node* createTuple(
1302       at::ArrayRef<Value*> values,
1303       TupleTypePtr optional_named_tuple = nullptr);
1304   TORCH_API Node* createTupleUnpack(Value* v);
1305   TORCH_API Node* createTupleIndex(
1306       Value* tup,
1307       Value* idx,
1308       const TypePtr& output_type);
1309   TORCH_API Node* createTupleSlice(
1310       Value* tup,
1311       int64_t beg,
1312       int64_t step_size,
1313       int64_t num_values);
1314   TORCH_API Node* createEnumName(Value* e);
1315   TORCH_API Node* createEnumValue(Value* e);
1316   TORCH_API Node* createList(
1317       const TypePtr& contained_type,
1318       at::ArrayRef<Value*> values);
1319   TORCH_API Node* createListUnpack(Value* v, size_t size);
1320   TORCH_API Node* createDict(
1321       const TypePtr& key_type,
1322       const TypePtr& value_type,
1323       at::ArrayRef<Value*> keys,
1324       at::ArrayRef<Value*> values);
1325   TORCH_API Node* createNumToTensor(Value* value);
1326   TORCH_API Node* createObject(const ClassTypePtr& type);
1327   TORCH_API Node* createSetAttr(
1328       Value* obj,
1329       const std::string& field,
1330       Value* newValue);
1331   TORCH_API Node* createGetAttr(Value* obj, const std::string& field);
insertGetAttrGraph1332   Value* insertGetAttr(Value* obj, const std::string& field) {
1333     return insertNode(createGetAttr(obj, field))->output();
1334   }
1335   TORCH_API Node* createStore(const std::string& name, Value* v);
1336   TORCH_API Node* createLoad(const std::string& name, const TypePtr& type);
1337   TORCH_API Node* createIsInstance(Value* v, at::ArrayRef<TypePtr> types);
1338 
1339   TORCH_API Value* insertUncheckedCast(Value* v, TypePtr type);
1340 
1341   // Insert a ToList operator with argument \p v and output type \p type.
1342   // \returns the output of the operation.
1343   TORCH_API Value* insertToList(Value* v, TypePtr type);
1344 
1345   TORCH_API Value* insertFunctionCall(
1346       Function* callee,
1347       const MatchedSchema& matched);
1348   TORCH_API Value* insertMethodCall(
1349       std::string method_name,
1350       const MatchedSchema& matched);
1351 
1352   // Note: defined in python_ir.cpp and can be used only in python extension
1353   Node* createPythonOp(
1354       THPObjectPtr&& pyobj,
1355       const std::string& cconv,
1356       pyobj_list&& scalar_args);
1357   // clone n, making a new node in _this_ graph.
1358   // use value_map to translate inputs of n to inputs of the cloned node
1359   // if copy_blocks is false, it will not recursively clone the nested blocks
1360   // this node contains.
1361   TORCH_API Node* createClone(
1362       Node* n,
1363       const std::function<Value*(Value*)>& value_map,
1364       bool copy_blocks = true);
1365 
1366   // Insert constant IValue into the graph.
1367   TORCH_API Value* insertConstant(
1368       const IValue& val,
1369       std::optional<SourceRange> loc = std::nullopt,
1370       std::optional<ScopePtr> scope = std::nullopt);
1371 
1372   // Schema-driven insert:
1373   // This inserts a node into the graph with inputs determined from args and
1374   // kwargs using Python argument matching rules, and checks that the op matches
1375   // a known schema.
1376   //
1377   // If this node successfully completes, it guarentees the node
1378   // is a correctly-formed invocation of opname
1379   TORCH_API Value* insert(
1380       Symbol opname,
1381       at::ArrayRef<NamedValue> args,
1382       at::ArrayRef<NamedValue> kwargs = {},
1383       const std::optional<SourceRange>& range = {});
1384 
appendNodeGraph1385   Node* appendNode(Node* n) {
1386     return block_->appendNode(n);
1387   }
1388 
prependNodeGraph1389   Node* prependNode(Node* n) {
1390     return block_->prependNode(n);
1391   }
1392 
1393   // insert before insert_before_ node
1394   // initialized to insert at the end of the top level block
1395   // can be changed with setInsertPoint()
insertNodeGraph1396   Node* insertNode(Node* n) {
1397     AT_ASSERT(
1398         insert_before_->inBlockList() &&
1399         "insert point node is no longer in a block list");
1400     return n->insertBefore(insert_before_);
1401   }
1402   // set where nodes are inserted to append to the end of this block
setInsertPointGraph1403   void setInsertPoint(Block* b) {
1404     AT_ASSERT(b->owningGraph() == this);
1405     setInsertPoint(b->return_node());
1406   }
1407   // set where nodes are inserted to insert _before_ this node
1408   // for implementation simplicity we only support inserting before a node for
1409   // now
setInsertPointGraph1410   void setInsertPoint(Node* n) {
1411     AT_ASSERT(n->owningGraph() == this && n->inBlockList());
1412     insert_before_ = n;
1413     predicted_insert_count_ = 0;
1414   }
insertPointGraph1415   Node* insertPoint() {
1416     return insert_before_;
1417   }
1418 
1419   // the top level block
blockGraph1420   Block* block() {
1421     return block_;
1422   }
blockGraph1423   const Block* block() const {
1424     return block_;
1425   }
1426 
1427   // Checks well-formedness and invariants of graph
1428   TORCH_API void lint() const;
1429   // for use in debugger
1430   TORCH_API void dump() const;
1431 
1432   TORCH_API ~Graph();
1433 
1434   TORCH_API std::string toString(bool print_source_locations = true) const;
1435 
1436   TORCH_API std::ostream& print(
1437       std::ostream& out,
1438       bool print_source_locations = true) const;
1439 
1440   friend TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g);
1441 
1442   TORCH_API std::shared_ptr<Graph> copy();
1443   TORCH_API std::unique_ptr<Graph> copyUnique();
1444   TORCH_API void remapTypes(const std::function<TypePtr(TypePtr)>& type_map);
1445 
1446  private:
1447   friend TORCH_API void Lint(const AliasDb* db);
1448   TORCH_API void freeNode(Node* n);
1449   TORCH_API void freeValue(Value* v);
1450   TORCH_API void freeBlock(Block* b);
1451   void cloneFrom(Graph& src);
1452 };
1453 
1454 /** \brief An utility class for setting temporary insertion points.
1455  *
1456  * When an object of this class is created, it stores the current insertion
1457  * point, sets the new one, and restores the original insertion point when the
1458  * object is destroyed.
1459  */
1460 struct WithInsertPoint {
WithInsertPointWithInsertPoint1461   WithInsertPoint(Node* n) : prev_(n->owningGraph()->insertPoint()) {
1462     n->owningGraph()->setInsertPoint(n);
1463   }
WithInsertPointWithInsertPoint1464   WithInsertPoint(Block* b) : WithInsertPoint(b->return_node()) {}
1465 
~WithInsertPointWithInsertPoint1466   ~WithInsertPoint() {
1467     prev_->owningGraph()->setInsertPoint(prev_);
1468   }
1469 
1470  private:
1471   Node* prev_;
1472 };
1473 
1474 /** \brief An utility class for setting temporary scopes.
1475  *
1476  * When an object of this class is created, it stores the current scope, sets
1477  * the new one, and restores the original scope when the object is destroyed.
1478  */
1479 struct WithCurrentScope {
WithCurrentScopeWithCurrentScope1480   WithCurrentScope(Graph& g, ScopePtr scope)
1481       : graph_(&g), prev_scope_(g.current_scope()) {
1482     g.set_current_scope(std::move(scope));
1483   }
~WithCurrentScopeWithCurrentScope1484   ~WithCurrentScope() {
1485     graph_->set_current_scope(prev_scope_);
1486   }
1487 
1488  private:
1489   Graph* graph_;
1490   ScopePtr prev_scope_;
1491 };
1492 
1493 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Value(Node * node_,size_t offset_)1494 inline Value::Value(Node* node_, size_t offset_)
1495     : node_(node_),
1496       offset_(offset_),
1497       unique_(node_->graph_->next_unique_++),
1498       type_(TensorType::get()) {
1499   node_->graph_->all_values.emplace(this);
1500 }
1501 
setType(TypePtr type)1502 inline Value* Value::setType(TypePtr type) {
1503   AT_ASSERT(type);
1504   if (auto dyn = type->castRaw<c10::DynamicType>()) {
1505     type = dyn->fallback();
1506   }
1507   type_ = std::move(type);
1508   for (Use& use : uses_) {
1509     use.user->op_ = nullptr;
1510   }
1511   return this;
1512 }
1513 
owningGraph()1514 inline Graph* Value::owningGraph() {
1515   return node()->owningGraph();
1516 }
1517 
owningGraph()1518 inline const Graph* Value::owningGraph() const {
1519   return node()->owningGraph();
1520 }
1521 
1522 /************* All nodes not required to be defined before Graph **************/
1523 struct ProfileOp : public Node {
1524   static const Symbol Kind;
ProfileOpProfileOp1525   ProfileOp(Graph* graph, std::function<void(std::vector<IValue>&)> callback)
1526       : Node(graph, ::c10::prim::profile), callback_(std::move(callback)) {}
1527 
1528   void cloneFrom(Node* other_) override;
1529   Node* allocNewInstance(Graph* g) override;
1530 
getCallbackProfileOp1531   const std::function<void(std::vector<IValue>&)>& getCallback() const {
1532     return callback_;
1533   }
1534 
setCallbackProfileOp1535   void setCallback(std::function<void(std::vector<IValue>&)> callback) {
1536     callback_ = std::move(callback);
1537   }
1538 
hasSeenTensorProfileOp1539   bool hasSeenTensor() const {
1540     return has_seen_tensor_;
1541   }
1542 
setHasSeenTensorProfileOp1543   void setHasSeenTensor(bool has_seen_tensor) {
1544     has_seen_tensor_ = has_seen_tensor;
1545   }
1546 
1547  private:
1548   std::function<void(std::vector<IValue>&)> callback_;
1549   bool has_seen_tensor_ = false;
1550 };
1551 
1552 struct TORCH_API ProfileIValueOp : public Node {
1553   static const Symbol Kind;
ProfileIValueOpProfileIValueOp1554   ProfileIValueOp(
1555       Graph* graph,
1556       std::function<void(std::vector<IValue>&)> callback)
1557       : Node(graph, ::c10::prim::profile_ivalue),
1558         callback_(std::move(callback)) {}
1559 
1560   void cloneFrom(Node* other_) override;
1561   Node* allocNewInstance(Graph* g) override;
1562 
getCallbackProfileIValueOp1563   const std::function<void(std::vector<IValue>&)>& getCallback() const {
1564     return callback_;
1565   }
1566 
setCallbackProfileIValueOp1567   void setCallback(std::function<void(std::vector<IValue>&)> callback) {
1568     callback_ = std::move(callback);
1569   }
1570 
1571  private:
1572   std::function<void(std::vector<IValue>&)> callback_;
1573 };
1574 
// Executes a Python function. Used for ops we can't optimize but still want to
// optimize around.
//
// Note: the actual implementation (ConcretePythonOp) is defined in
// python_ir.cpp, which is not included in libtorch.so. We still include some
// bits and pieces of PythonOp here to enable writing simple passes
// generically. In general, python-aware bits need to be moved to the
// descendant classes.
//
// Every member below is pure virtual, so this struct only describes the
// interface that concrete subclasses must provide.
struct TORCH_API PythonOp : public Node {
  using Node::Node;

  // NOTE(review): presumably the name of the wrapped Python callable.
  virtual std::string name() const = 0;
  // Writes this op's scalar arguments to `out`.
  virtual void writeScalars(std::ostream& out) const = 0;
  void cloneFrom(Node* other_) override = 0;
  Node* allocNewInstance(Graph* g) override = 0;
  // recover the autograd.Function instance, if this PythonOp's function
  // was originally SomeFunction.apply
  // used in ONNX for discovering symbolics
  virtual std::optional<THPObjectPtr> autogradFunction() const = 0;

  // Python-specific lint checks, supplied by the concrete subclass.
  virtual void lint_python() const = 0;
};
1596 
// Graph-level lint entry point. NOTE(review): presumably equivalent to
// graph->lint() — confirm in ir.cpp.
TORCH_API void LintGraph(const std::shared_ptr<Graph>& graph);

// Free-function counterpart of Graph::createTupleUnpack. NOTE(review):
// presumably inserts a TupleUnpack of \p v and returns its outputs — confirm.
TORCH_API at::ArrayRef<Value*> createTupleUnpack(Value* v);

/** Insert graph \p callee into graph \p g using \p inputs as input values.
 * The insertion happens at the current insertion point.
 * Optionally, one can also pass \p value_map to get a map between \p callee
 * values and their cloned copies in \p g.
 */
TORCH_API std::vector<Value*> insertGraph(
    Graph& g,
    Graph& callee,
    ArrayRef<Value*> inputs);
TORCH_API std::vector<Value*> insertGraph(
    Graph& g,
    Graph& callee,
    ArrayRef<Value*> inputs,
    std::unordered_map<Value*, Value*>& value_map);

/** Insert function \p callee after node \p to_replace, remove the node and
 * replace all its uses with corresponding outputs of the inserted function.
 * This asserts that the number of outputs of the original node and the
 * graph are the same.
 */
TORCH_API std::vector<Value*> inlineCallTo(
    Node* to_replace,
    GraphFunction* callee,
    bool use_graph = true);

// Overload that inlines \p callee_graph on behalf of \p callee.
// NOTE(review): confirm how callee_graph relates to callee's own graph.
TORCH_API std::vector<Value*> inlineCallTo(
    Node* to_replace,
    GraphFunction* callee,
    Graph* callee_graph);

/** If there is only one value in \p outputs and its kind is Tuple, insert a
 * tuple unpack node and return the resulting values.
 */
TORCH_API std::vector<Value*> unpackOutputs(const std::vector<Value*>& outputs);

// Collects nodes whose kind equals \p kind. NOTE(review): presumably
// \p recurse controls whether nested blocks are searched too.
TORCH_API std::vector<Node*> findAllNodes(Graph& g, Symbol kind, bool recurse);
TORCH_API std::vector<Node*> findAllNodes(Block& b, Symbol kind, bool recurse);
TORCH_API std::vector<Node*> findAllNodes(
    at::ArrayRef<Block*> a,
    Symbol kind,
    bool recurse);

// A set of operators given as schema/signature string literals. Operators are
// stored grouped by Symbol; OperatorMap and FunctionSchemaMap below perform
// the analogous grouping.
struct TORCH_API OperatorSet {
  OperatorSet(std::initializer_list<const char*> sig_literals);
  // Returns every operator in the set.
  std::vector<std::shared_ptr<Operator>> getOps() const;
  // Adds more operators given as signature literals.
  void insert(std::initializer_list<const char*> sig_literals);

 private:
  friend struct Node;
  std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>> ops;
};
1652 
1653 template <typename T>
1654 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
1655 struct OperatorMap {
1656   // Type aliasing
1657   using OpMapType = typename std::pair<std::shared_ptr<Operator>, T>;
1658   using ValueType = std::vector<OpMapType>;
1659   using MapType = std::unordered_map<Symbol, ValueType>;
1660 
1661   OperatorMap() = default;
1662   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
OperatorMapOperatorMap1663   explicit OperatorMap(
1664       std::initializer_list<std::pair<std::shared_ptr<Operator>, T>> init) {
1665     insert(init);
1666   }
1667   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
OperatorMapOperatorMap1668   explicit OperatorMap(std::initializer_list<std::pair<const char*, T>> init) {
1669     insert(init);
1670   }
1671 
insertOperatorMap1672   void insert(const std::shared_ptr<Operator>& op, T val) {
1673     // Remove if exists before insert
1674     erase(op);
1675     map[Symbol::fromQualString(op->schema().name())].emplace_back(
1676         std::make_pair(op, val));
1677   }
1678 
insertOperatorMap1679   void insert(const OperatorSet& op_set, T val) {
1680     for (auto& op : op_set.getOps()) {
1681       insert(op, val);
1682     }
1683   }
1684 
insertOperatorMap1685   void insert(
1686       std::initializer_list<std::pair<std::shared_ptr<Operator>, T>> v) {
1687     for (auto& el : v) {
1688       insert(el.first, el.second);
1689     }
1690   }
1691 
insertOperatorMap1692   void insert(std::initializer_list<std::pair<const char*, T>> v) {
1693     for (auto& el : v) {
1694       insert(getOperatorForLiteral(el.first), el.second);
1695     }
1696   }
1697 
eraseOperatorMap1698   void erase(const std::shared_ptr<Operator>& op) {
1699     auto it = map.find(Symbol::fromQualString(op->schema().name()));
1700     if (it == map.end()) {
1701       return;
1702     }
1703     for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) {
1704       if (vit->first->schema() == op->schema()) {
1705         it->second.erase(vit);
1706         break;
1707       }
1708     }
1709     if (it->second.size() == 0) {
1710       map.erase(Symbol::fromQualString(op->schema().name()));
1711     }
1712   }
1713 
containsOperatorMap1714   bool contains(const Operator& op) const {
1715     const auto it = map.find(Symbol::fromQualString(op.schema().name()));
1716     if (it == map.end()) {
1717       return false;
1718     }
1719     for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) {
1720       if (vit->first->schema() == op.schema()) {
1721         return true;
1722       }
1723     }
1724     return false;
1725   }
1726 
containsOperatorMap1727   bool contains(const Node* n) const {
1728     return n->maybeOperator() && contains(n->getOperator());
1729   }
1730 
findOperatorMap1731   std::optional<T> find(const Operator& op) {
1732     const auto it = map.find(Symbol::fromQualString(op.schema().name()));
1733     if (it == map.end()) {
1734       return std::nullopt;
1735     }
1736     for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) {
1737       if (vit->first->schema() == op.schema()) {
1738         return vit->second;
1739       }
1740     }
1741     return std::nullopt;
1742   }
1743 
1744   // TODO: return iterator
getAllKeysAndValuesOperatorMap1745   std::vector<OpMapType> getAllKeysAndValues() const {
1746     std::vector<OpMapType> keys_values;
1747     keys_values.reserve(map.size());
1748     for (auto& symbol_mapping : map) {
1749       auto& vec = symbol_mapping.second;
1750       for (auto& pair : vec) {
1751         keys_values.push_back(pair);
1752       }
1753     }
1754     return keys_values;
1755   }
1756 
1757  private:
1758   friend struct Node;
1759   MapType map;
1760 };
1761 
1762 template <typename T>
1763 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
1764 struct FunctionSchemaMap {
1765   // Type aliasing
1766   using FuncSchemaMapType = typename std::pair<FunctionSchema, T>;
1767   using ValueType = std::vector<FuncSchemaMapType>;
1768   using MapType = std::unordered_map<Symbol, ValueType>;
1769 
1770   FunctionSchemaMap() = default;
insertFunctionSchemaMap1771   void insert(const FunctionSchema& schema, T val) {
1772     // Remove if exists before insert
1773     erase(schema);
1774     map[Symbol::fromQualString(schema.name())].emplace_back(
1775         std::make_pair(schema, val));
1776   }
1777 
eraseFunctionSchemaMap1778   void erase(const FunctionSchema& schema) {
1779     auto it = map.find(Symbol::fromQualString(schema.name()));
1780     if (it == map.end()) {
1781       return;
1782     }
1783     for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) {
1784       if (vit->first == schema) {
1785         it->second.erase(vit);
1786         break;
1787       }
1788     }
1789     if (it->second.size() == 0) {
1790       map.erase(Symbol::fromQualString(schema.name()));
1791     }
1792   }
1793 
containsFunctionSchemaMap1794   bool contains(const FunctionSchema& schema) const {
1795     const auto it = map.find(Symbol::fromQualString(schema.name()));
1796     if (it == map.end()) {
1797       return false;
1798     }
1799     for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) {
1800       if (vit->first->schema() == schema) {
1801         return true;
1802       }
1803     }
1804     return false;
1805   }
1806 
findFunctionSchemaMap1807   std::optional<T> find(const FunctionSchema& schema) const {
1808     const auto it = map.find(Symbol::fromQualString(schema.name()));
1809     if (it == map.end()) {
1810       return std::nullopt;
1811     }
1812     for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) {
1813       if (vit->first == schema) {
1814         return vit->second;
1815       }
1816     }
1817     return std::nullopt;
1818   }
1819 
1820   // TODO: return iterator
getAllKeysAndValuesFunctionSchemaMap1821   std::vector<FuncSchemaMapType> getAllKeysAndValues() const {
1822     std::vector<FuncSchemaMapType> keys_values;
1823     keys_values.reserve(map.size());
1824     for (auto& symbol_mapping : map) {
1825       auto& vec = symbol_mapping.second;
1826       for (auto& pair : vec) {
1827         keys_values.push_back(pair);
1828       }
1829     }
1830     return keys_values;
1831   }
1832 
1833  private:
1834   friend struct Node;
1835   MapType map;
1836 };
1837 
1838 } // namespace torch::jit
1839