1 #pragma once
2
3 #include <torch/csrc/jit/ir/attributes.h>
4 #include <torch/csrc/jit/ir/graph_node_list.h>
5 #include <torch/csrc/jit/ir/named_value.h>
6 #include <torch/csrc/jit/ir/scope.h>
7 #include <torch/csrc/jit/runtime/operator.h>
8
9 #include <torch/csrc/Export.h>
10 #include <torch/csrc/utils/python_stub.h>
11 #include <torch/csrc/utils/schema_info.h>
12
13 #include <ATen/Utils.h>
14 #include <ATen/core/Tensor.h>
15 #include <ATen/core/dynamic_type.h>
16 #include <ATen/core/enum_type.h>
17 #include <ATen/core/functional.h>
18 #include <ATen/core/interned_strings.h>
19 #include <ATen/core/ivalue.h>
20 #include <ATen/core/jit_type.h>
21 #include <c10/util/ArrayRef.h>
22 #include <c10/util/Exception.h>
23 #include <optional>
24
25 #include <functional>
26 #include <iosfwd>
27 #include <unordered_set>
28 #include <vector>
29
30 // Forward declare, the real meat is in python_ir.cpp
31 template <class T>
32 class THPPointer;
33 using THPObjectPtr = THPPointer<PyObject>;
34 using pyobj_list = std::vector<THPObjectPtr>;
35
36 namespace torch::jit {
37 namespace utils {
38 TORCH_API std::string getNodesModuleHierarchy(const Node& n);
39 } // namespace utils
40 class AliasDb;
41
42 using ::c10::Argument;
43 using ::c10::FunctionSchema;
44 using ::c10::Symbol;
45
46 using ::c10::ivalue::Shared;
47
48 using ::c10::IValue;
49 using ::c10::ivalue::Future;
50
51 using ::c10::ivalue::ConstantString;
52
53 #define C10_USING(T) using ::c10::T;
54 C10_FORALL_TYPES(C10_USING)
55 #undef C10_USING
56
57 #define C10_USING(T) using ::c10::T##Ptr;
58 C10_FORALL_TYPES(C10_USING)
59 #undef C10_USING
60
61 using ::c10::Type;
62 using ::c10::TypeEnv;
63 using ::c10::TypePtr;
64
65 using ::c10::getTypePtr;
66 using ::c10::MatchTypeReturn;
67 using ::c10::TypeKind;
68
69 using ::c10::fmap;
70
71 namespace prim {
72 using namespace ::c10::prim;
73 }
74 namespace attr {
75 using namespace ::c10::attr;
76 }
77 namespace aten {
78 using namespace ::c10::aten;
79 }
80 namespace cuda {
81 #if !defined(USE_ROCM)
82 using namespace ::c10::cuda;
83 #endif
84 } // namespace cuda
85
86 struct Function;
87 struct GraphFunction;
88 struct MatchedSchema;
89
90 // A Graph represents one "function" of computation.
91 // It uses a simple ownership model where the graph owns all the nodes inside
92 // it. All references inside the graph are raw pointers. Destroying the Graph
93 // will invalidate any pointers to nodes in the graph.
94 struct Graph;
95
96 // Node is the base class of the IR graph. It represents one computation
97 // and dependencies on a list of Values. The "prim-ops", so to speak.
98 struct Node;
99
100 // A Value represents an input or output to node that is either a
101 // Tensor or an opaque Handle object, as determined by type().
102 struct Value;
103
104 TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g);
105 TORCH_API std::ostream& operator<<(std::ostream& out, const Node& n);
106
107 // A list of nodes, with inputs and outputs
108 struct Block;
109
110 // Each use is represented by this type, see 'Node::uses()'
// 'user' is the consumer of the value; 'offset' is the index into
// 'user's input list at which the produced value appears.
struct Node; // forward declared above; repeated here so Use is self-contained
struct Use {
  Use(Node* user, size_t offset) : user(user), offset(offset) {}
  // The node consuming the value.
  Node* user;
  // Index into user's input list at which the value appears.
  size_t offset;

  // Equality on (user, offset). Const-qualified so that const Use objects
  // and references can be compared (the original was a non-const member,
  // which made `const Use` uncomparable).
  bool operator==(const Use& b) const {
    return user == b.user && offset == b.offset;
  }
};
122
123 // Note [User node does not uniquely identify use]
124 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
125 // A while back, we wrote some code manipulating uses that looked like this:
126 //
127 // for (auto& use : used_val->uses_) {
128 // if (use.user == this_node) {
129 // use.offset += 1;
130 // break;
131 // }
132 // }
133 //
134 // This code is trying to find a particular use (our node's use) to update it.
135 // However, it's wrong: there may be *multiple* uses of a value %x in a node,
136 // as might be the case in this IR:
137 //
138 // %y = Add %x %x
139 //
140 // In this case, there are two uses of %x whose user is the node 'Add %x %x'.
141 // So, "use induced by this node" is not a well-formed concept.
142 //
143 // If you are looking for "use induced by an input", it's best to use
144 // findUseForInput() to get it.
145
146 // the list types are intentionally simple, but we type-def
147 // them here so if we need to change them, refactoring will be easier
148 using node_list = std::vector<Node*>;
149 using value_list = std::vector<Value*>;
150 using use_list = std::vector<Use>;
151 template <typename T>
152 using ArrayRef = at::ArrayRef<T>;
153 using NodeKind = Symbol;
154 using topo_position_t = int64_t;
155 using ValueSet = std::unordered_set<const Value*>;
156
157 struct OperatorSet;
158 template <typename T>
159 struct OperatorMap;
160
// This is a wrapper that allows safely invalidating the Python object
// when the C++ object for a Node/Value/Block is deleted.
// Like much of the graph API, it is not safe for different threads to
// access the same graph concurrently.
template <typename T>
struct Wrap {
  explicit Wrap(T* p) : elem(p) {}
  // Notify the registered callback (if one was installed) that the wrapped
  // object is going away, then null out the pointer so stale accesses from
  // Python can be detected.
  void clear() {
    if (clear_cb != nullptr) {
      clear_cb(elem);
    }
    elem = nullptr;
  }
  // Raw pointer to the wrapped Node/Value/Block; nullptr after clear().
  T* elem;
  // Optional hook invoked with `elem` just before it is invalidated.
  void (*clear_cb)(void*){nullptr};
};
177
// A Value is a typed SSA value in the IR: it is produced as one output of
// exactly one Node (node_ / offset_) and consumed by zero or more uses.
struct Value {
  AT_DISALLOW_COPY_AND_ASSIGN(Value);
  Value(Node* node_, size_t offset_);

 private:
  friend struct Node;
  friend struct Graph;
  // The node that produces this value.
  Node* node_;
  // Which output of node_ this value is.
  size_t offset_;
  size_t unique_ = 0; // unique id
  // All (consumer node, input index) pairs that reference this value.
  use_list uses_;
  // Optional user-visible name; empty means "fall back to the unique id".
  std::string unique_name_;
  TypePtr type_;
  // a managing wrapper for Python to allow invalidation
  std::shared_ptr<Wrap<Value>> wrap_;

 public:
  // Sets the static type of this value; returns `this` for chaining.
  Value* setType(TypePtr type);
  // Set this value's type from a concrete tensor / object instance.
  TORCH_API void inferTypeFrom(const at::Tensor& output);
  TORCH_API void inferTypeFrom(
      const c10::intrusive_ptr<c10::ivalue::Object>& output);
  // Returns the static type; asserts that a type has been set.
  const TypePtr& type() const {
    AT_ASSERT(type_ != nullptr);
    return type_;
  }
  bool requires_grad() const {
    return type()->requires_grad();
  }
  // True only for tensor-typed values whose TensorType is fully specified.
  bool isCompleteTensor() const {
    if (auto pt = type()->cast<TensorType>()) {
      return pt->isComplete();
    }
    return false;
  }
  TORCH_API bool mustBeNone() const;
  TORCH_API bool mustNotBeNone() const;
  // Graph-wide unique id of this value.
  size_t unique() const {
    return unique_;
  }
  bool hasDebugName() const {
    return !unique_name_.empty();
  }
  static bool isValidName(const std::string& name);
  TORCH_API Value* setDebugName(const std::string& name);
  // Debug name if one was assigned, otherwise the unique id as a string.
  std::string debugName() const {
    if (hasDebugName()) {
      return unique_name_;
    }
    return std::to_string(unique());
  }
  TORCH_API std::string debugNameBase() const;
  // The node that produces this value.
  Node* node() {
    return node_;
  }
  // Index of this value among its producing node's outputs.
  size_t offset() const {
    return offset_;
  }
  void setOffset(size_t offset) {
    offset_ = offset;
  }
  const Node* node() const {
    return node_;
  }

  /**
   * @warning NEVER pass raw pointer of smart pointer managed Graph to Python.
   * Check #87343 for details.
   */
  Graph* owningGraph();
  const Graph* owningGraph() const;
  // TODO: make this more const correct
  const use_list& uses() const {
    return uses_;
  }

  bool hasUses() const {
    return !uses().empty();
  }

  TORCH_API void replaceFirstUseWith(Value* newValue);

  // Replaces all uses of this value with 'newValue'.
  //
  // Given: %3 = f(%1, %2)
  //        %4 = g(%3)
  //        %5 = h(%3, %3)
  // Execute: %3.replaceAllUsesWith(%6)
  // Result:  %3 = f(%1, %2)
  //          %4 = g(%6)
  //          %5 = h(%6, %6)
  TORCH_API void replaceAllUsesWith(Value* newValue);

  // Replaces all uses of this value with 'newValue' after 'node'.
  // Given: %3 = f(%1, %2)
  //        %4 = g(%3)
  //        %5 = inplace_(%3)
  //        %6 = h(%3, %3)
  // Execute: %3.replaceAllUsesAfterNodeWith(%5.node(), %5)
  // Result:  %3 = f(%1, %2)
  //          %4 = g(%3)
  //          %5 = inplace_(%3)
  //          %6 = h(%5, %5)
  // XXX: does not check scoping legality, consider using
  // replaceAllUsesDominatedByNodeWith
  TORCH_API void replaceAllUsesAfterNodeWith(const Node* node, Value* newValue);

  // Replaces all uses of this value with 'newValue' that are dominated by
  // 'node'. Given:
  // x = op(...).
  // if cond:
  //    z = foo(..)
  //    bar(x)
  // else:
  //    print(x)
  // x.replaceAllUsesDominatedByNodeWith(foo, z) would replace bar(x)
  // but not print(x) because print is not dominated by foo.
  // replaceAllUsesAfterNode does not check domination, so in this example
  // it would produce invalid IR.
  TORCH_API void replaceAllUsesDominatedByNodeWith(
      const Node* node,
      Value* newValue);

  // Copies debug name and type from 'from'; returns `this` for chaining.
  TORCH_API Value* copyMetadata(Value* from);

  // Lazily creates (and caches) the Python-invalidation wrapper for this
  // value; it is cleared in ~Value().
  TORCH_API std::shared_ptr<Wrap<Value>> wrap() {
    if (!wrap_) {
      wrap_ = std::make_shared<Wrap<Value>>(this);
    }
    return wrap_;
  }

  virtual ~Value() {
    if (wrap_) {
      wrap_->clear();
    }
  }
};
315
316 struct TORCH_API Node {
317 AT_DISALLOW_COPY_AND_ASSIGN(Node);
318 friend struct Graph;
319 friend struct Block;
320 friend struct Value;
321 friend graph_node_list;
322 friend const_graph_node_list;
323 friend graph_node_list_iterator;
324 friend const_graph_node_list_iterator;
325
326 private:
327 const NodeKind kind_;
328 std::vector<Value*> inputs_;
329 std::vector<Value*> outputs_;
330 // subblocks
331 std::vector<Block*> blocks_;
332 Graph* graph_;
333 Block* owning_block_;
334 std::optional<SourceRange> source_range_;
335 ScopePtr scope_;
336 std::optional<InlinedCallStackPtr> callstack_;
337 // Assumes FunctionSchemas are persistent, so we don't manage their lifetime.
338 // This field is effective a cache that's populated on attribute lookups and
339 // invalidated every time we perform an operation that could potentially
340 // change the schema. note: mutable because schema_ is effectively a cache
341 mutable const Operator* op_;
342 topo_position_t topo_position_ = 0;
343 // a managing wrapper for Python to allow invalidation
344 std::shared_ptr<Wrap<Node>> wrap_;
345 // Stores the full schema name, if the operator is historic
346 // When the operator is deprecated or the name of the operator
347 // is changed, we need to rely on this name
348 // to retrieve old schemas to successfully apply upgraders
349 // for this operator.
350 std::optional<std::string> historic_schema_name_ = std::nullopt;
351
352 protected:
353 Node(Graph* graph_, NodeKind kind_); // defined after graph
354 public:
355 // Each Node but Return/Param Nodes are associated with exactly one
356 // place in the Node list of the Graph. The Graph itself is a circular
357 // doubly-linked list. The Return Node is used as the sentinel for the
358 // "beginning"/"end" of the list. This means that you can tell when
359 // you've traversed the entire list without means worrying about null
360 // pointers. `next_in_graph[0]` is the pointer to the next Node, while
361 // `next_in_graph[1]` is the pointer to the previous Node. The
362 // linked list is implemented as an array to allow the same iterator
363 // class for forward and reversed Node lists. Taken together, this
364 // list also represents a topological sort of the Nodes in the Graph.
365 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-non-private-member-variables-in-classes,modernize-avoid-c-arrays)
366 Node* next_in_graph[2] = {nullptr, nullptr};
367
wrapNode368 std::shared_ptr<Wrap<Node>> wrap() {
369 if (!wrap_) {
370 wrap_ = std::make_shared<Wrap<Node>>(this);
371 }
372 return wrap_;
373 }
374
getHistoricSchemaNameNode375 const std::optional<std::string> getHistoricSchemaName() {
376 return historic_schema_name_;
377 }
378
setHistoricSchemaNameNode379 void setHistoricSchemaName(const std::string& name) {
380 historic_schema_name_ = name;
381 }
382
nextNode383 Node*& next() {
384 return next_in_graph[kNextDirection];
385 }
prevNode386 Node*& prev() {
387 return next_in_graph[kPrevDirection];
388 }
nextNode389 Node* const& next() const {
390 return next_in_graph[kNextDirection];
391 }
prevNode392 Node* const& prev() const {
393 return next_in_graph[kPrevDirection];
394 }
395
kindNode396 NodeKind kind() const {
397 return kind_;
398 }
setSourceRangeNode399 Node* setSourceRange(SourceRange r) {
400 source_range_ = std::move(r);
401 return this;
402 }
403 SourceRange sourceRange() const;
404
405 /**
406 * @warning NEVER pass raw pointer of smart pointer managed Graph to Python.
407 * Check #87343 for details.
408 */
owningGraphNode409 Graph* owningGraph() {
410 return graph_;
411 }
owningGraphNode412 const Graph* owningGraph() const {
413 return graph_;
414 }
owningBlockNode415 Block* owningBlock() {
416 return owning_block_;
417 }
owningBlockNode418 const Block* owningBlock() const {
419 return owning_block_;
420 }
scopeNode421 ScopePtr scope() {
422 return scope_;
423 }
setScopeNode424 void setScope(ScopePtr scope) {
425 scope_ = std::move(scope);
426 }
scopeNameNode427 std::string scopeName() const {
428 if (!scope_) {
429 return "";
430 }
431 return scope_->namesFromRoot();
432 }
433
434 // Copies the source range, scope and callstack from another node.
copyMetadataNode435 Node* copyMetadata(Node* from) {
436 this->setSourceRange(from->sourceRange());
437 this->setScope(from->scope());
438 if (auto cs = from->callstack()) {
439 this->setCallStack(*cs);
440 }
441 return this;
442 }
443
callstackNode444 std::optional<InlinedCallStackPtr> callstack() const {
445 return callstack_;
446 }
setCallStackNode447 void setCallStack(InlinedCallStackPtr cs) {
448 callstack_ = std::move(cs);
449 }
450
451 // NB: This returns an ArrayRef; that means that it will
452 // get invalidated if you resize inputs (e.g., using addInput)
453 // We can't return a std::vector<Node*>& because there's no
454 // way to soundly cast to std::vector<const Node*> (an insane
455 // implementation of std::vector could make this representationally
456 // different.)
inputsNode457 at::ArrayRef<Value*> inputs() {
458 return inputs_;
459 }
inputsNode460 at::ArrayRef<const Value*> inputs() const {
461 // Vectors are not convertible in const-ness of elements, but
462 // raw pointers are.
463 return {inputs_.data(), inputs_.size()};
464 }
465 // NB: This returns an ArrayRef; that means that it will
466 // get invalidated if you resize inputs (e.g., using addInput)
467 // We can't return a std::vector<Node*>& because there's no
468 // way to soundly cast to std::vector<const Node*> (an insane
469 // implementation of std::vector could make this representationally
470 // different.)
outputsNode471 at::ArrayRef<Value*> outputs() {
472 return outputs_;
473 }
outputsNode474 at::ArrayRef<const Value*> outputs() const {
475 // Vectors are not convertible in const-ness of elements, but
476 // raw pointers are.
477 return {outputs_.data(), outputs_.size()};
478 }
outputNode479 Value* output(size_t i) const {
480 return outputs_.at(i);
481 }
hasUsesNode482 bool hasUses() const {
483 for (auto o : outputs()) {
484 if (!o->uses().empty()) {
485 return true;
486 }
487 }
488 return false;
489 }
490
491 void replaceAllUsesWith(Node* n);
492
493 // replaces `this` with a new node with the same inputs and outputs
494 // but a new node symbol. does not destroy `this`
495 Node* replaceWithNewSymbol(Symbol new_symbol);
496
497 // Checks if this node is dominated by `dominator` which means that
498 // `dominator` will always be executed before `this` and `dominator`
499 // is in scope of `this.
500 bool isDominatedBy(const Node* dominator) const;
501
502 // lots of things like chunk have a single input or single output, so we have
503 // a helper to make accessing it easier
inputNode504 Value* input() {
505 AT_ASSERT(inputs_.size() == 1);
506 return inputs_.at(0);
507 }
outputNode508 Value* output() {
509 AT_ASSERT(outputs_.size() == 1);
510 return outputs_.at(0);
511 }
outputNode512 const Value* output() const {
513 AT_ASSERT(outputs_.size() == 1);
514 return outputs_.at(0);
515 }
inputNode516 const Value* input() const {
517 AT_ASSERT(inputs_.size() == 1);
518 return inputs_.at(0);
519 }
520 // Access a particular input. This is a checked index.
inputNode521 Value* input(size_t i) const {
522 return inputs_.at(i);
523 }
524
525 bool hasNamedInput(const std::string& unqualName) const;
526 Value* namedInput(const std::string& unqualName) const;
527 Value* namedInput(Symbol name) const;
528
529 std::optional<IValue> get(Symbol name) const;
530
531 template <typename T>
getNode532 std::optional<T> get(Symbol name) const {
533 if (auto v = get(name)) {
534 return v->template to<T>();
535 }
536 return std::nullopt;
537 }
538
539 // Returns true if the value of input name is statically known
is_constantNode540 bool is_constant(Symbol name) const {
541 return static_cast<bool>(get(name));
542 }
543 bool mustBeNone() const;
544
545 bool isNondeterministic() const;
546 bool hasSideEffects() const;
547
548 // instructions lowered by the interpreter and not run in the optimized graph
notExecutedOpNode549 bool notExecutedOp() const {
550 return kind_ == prim::Constant || kind_ == prim::profile ||
551 kind_ == prim::profile_ivalue;
552 }
553
554 // Graphs
555
556 // Note [Topological invariant]
557 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
558 // We always maintain an up-to-date topological ordering of all nodes via
559 // the next()/prev() links. All transformations to graphs must preserve
560 // this topological ordering: for example, it is only valid to 'addInput'
561 // with an input which is topologically before the current node.
562 //
563 // Usually, it is obvious whether or not topological order is maintained;
564 // for example, if you are adding nodes to the end of the topsort, it's
565 // impossible for them to refer to inputs that are not in the topsort.
566 // If it is not obvious, please comment accordingly.
567
568 // Add 'node' as an input to 'this' at the end of existing
569 // arguments. Returns the added node for ease of chaining.
570 //
571 // Given: %3 = f(%1, %2)
572 // Execute: %3.addInput(%4)
573 // Result: %3 = f(%1, %2, %4)
574 Value* addInput(Value* value);
575
576 // Add 'value' as an input to 'this' at the specified position in the
577 // arguments. Returns the added value for ease of chaining.
578 Value* insertInput(size_t i, Value* value);
579
580 // Replace the input of 'this' at position 'i' with
581 // 'newValue', returning the old node.
582 //
583 // Given: %3 = f(%1, %2)
584 // Execute: %3.replaceInput(1, %4)
585 // Result: %3 = f(%1, %4)
586 Value* replaceInput(size_t i, Value* newValue);
587
588 // Replace all occurrences of 'from' in the inputs of this
589 // node with 'to'. Corresponds to llvm's replaceUsesOfWith.
590 //
591 // Given: %3 = f(%1, %2, %1)
592 // Execute: %3.replaceInputWith(%1, %4)
593 // Result: %3 = f(%4, %2, %4)
594 void replaceInputWith(Value* from, Value* to);
595
596 Value* addOutput();
597
598 Value* insertOutput(size_t i);
599
600 void eraseOutput(size_t i);
601
602 Block* addBlock();
603 void eraseBlock(size_t i);
604
605 // Each Node can have a list of subblocks. These are used to define structured
606 // nested control flow operators such as If and Loop.
607 // The meaning of a block is specific to the kind of node it is in, but
608 // all blocks share these semantics:
609 // * Nested lexical scoping: If a node 'Parent' has a subblock which contains
610 // a node 'Child', Child can use any value that was in scope for the Parent
611 // node in addition to any values defined before 'Child' in the subblock.
612 // * The list of inputs to the block are in scope for the duration of the
613 // block
614 // * the outputs of the Parent node are not in scope for the subblocks
615 // Typically the inputs to a block that represents control flow act as
616 // as the equivalents phi-nodes in standard SSA form,
617 // defining a new Value to represent any term that has multiple
618 // definitions depending on how control flowed. Outputs of the node containing
619 // control flow serve a similiar purpose defining new values for variables
620 // that would have different definitions depending on which way control
621 // flowed.
622
blocksNode623 at::ArrayRef<Block*> blocks() {
624 return blocks_;
625 }
blocksNode626 at::ArrayRef<const Block*> blocks() const {
627 // Vectors are not convertible in const-ness of elements, but
628 // raw pointers are.
629 return {blocks_.data(), blocks_.size()};
630 }
631
632 // Is 'this' before 'n' in the topological order?
633 bool isBefore(const Node* n) const;
634
635 // Is 'this' after 'n' in the topological order?
636 bool isAfter(const Node* n) const;
637
638 // Insert unattached 'this' node before 'n' in the topological order.
639 // Returns this (for chaining).
640 //
641 // Given: %3 = f(%1, %2)
642 // %4 = g(%3)
643 // and unattached: %5 = h(%1)
644 // Execute: %5.insertBefore(%4)
645 // Result: %3 = f(%1, %2)
646 // %5 = h(%1)
647 // %4 = g(%3)
648 Node* insertBefore(Node* n);
649
650 // Insert unattached 'this' node after 'n' in the topological order.
651 // Returns this (for chaining).
652 //
653 // Given: %3 = f(%1, %2)
654 // %4 = g(%3)
655 // and unattached: %5 = h(%1)
656 // Execute: %5.insertAfter(%4)
657 // Result: %3 = f(%1, %2)
658 // %4 = g(%3)
659 // %5 = h(%1)
660 Node* insertAfter(Node* n);
661
662 // Move 'this' (already in the graph) after 'n' in the topological order.
663 //
664 // NOTE: Does not check that value dependencies are preserved, see
665 // AliasDb::moveAfterTopologicallyValid
666 //
667 // Given: %2 = f(%1)
668 // %3 = g(%1)
669 // Execute: %2.moveAfter(%3)
670 // Result: %3 = g(%1)
671 // %2 = f(%1)
672 //
673 void moveAfter(Node* n);
674
675 // Move a node 'n' (already in the graph) before 'this' in the topological
676 // order.
677 //
678 // NOTE: Does not check that value dependencies are preserved, see
679 // AliasDb::moveBeforeTopologicallyValid
680 //
681 // Given: %2 = f(%1)
682 // %3 = g(%1)
683 // Execute: %3.moveBefore(%2)
684 // Result: %3 = g(%1)
685 // %2 = f(%1)
686 void moveBefore(Node* n);
687
688 // Remove the input at 'i' from this node.
689 //
690 // WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling
691 // removeInput.
692 //
693 // Given: %3 = f(%1, %2)
694 // Execute: %3.removeInput(1)
695 // Result: %3 = f(%1)
696 void removeInput(size_t i);
697
698 // Remove all inputs from a node.
699 //
700 // Given: %3 = f(%1, %2)
701 // Execute: %3.removeAllInputs()
702 // Result: %3 = f()
703 void removeAllInputs();
704
705 // Remove all outputs from a node.
706 //
707 // Given: %1, %2 = f()
708 // Execute:removeAllInputs()
709 // Result: = f()
710 void removeAllOutputs();
711
712 // Rearrange the ordering of inputs or outputs of a node
713 // Given: %3 = f(%1, %2)
714 // Execute: %3.permuteInputs({1, 0})
715 // Result: %3 = f(%2, %1)
716 // Each index must appear exactly once
717 void permuteInputs(const std::vector<size_t>& new_inputs);
718 void permuteOutputs(const std::vector<size_t>& new_inputs);
719
720 // iterators of the node list starting at this node
721 // useful for resuming a search starting at this node
iteratorNode722 inline graph_node_list_iterator iterator() {
723 return {this, 0};
724 }
reverseIteratorNode725 inline graph_node_list_iterator reverseIterator() {
726 return iterator().reverse();
727 }
iteratorNode728 inline const_graph_node_list_iterator iterator() const {
729 return {this, 0};
730 }
reverseIteratorNode731 inline const_graph_node_list_iterator reverseIterator() const {
732 return iterator().reverse();
733 }
734
735 // Remove 'this' from the instruction list and deallocate it.
736 //
737 // Invariant: no outputs of 'this' may have any uses.
738 //
739 // Given: %2 = f(%1)
740 // %3 = g(%1)
741 // Execute: %2.destroy()
742 // Result: %3 = g(%1)
743 void destroy();
744
745 // Dynamically cast this node to the subclass indicated by the
746 // template variable, returning nullptr if the cast is invalid..
747 //
748 // Example usage: if(auto s = n.cast<Select>()) { ... }
749 template <typename T>
castNode750 T* cast() {
751 if (T::Kind == kind()) {
752 return static_cast<T*>(this);
753 }
754 return nullptr;
755 }
756 template <typename T>
castNode757 const T* cast() const {
758 if (T::Kind == kind()) {
759 return static_cast<const T*>(this);
760 }
761 return nullptr;
762 }
763
764 template <typename T>
expectNode765 T* expect() {
766 TORCH_CHECK(
767 T::Kind == kind(),
768 "expected a ",
769 T::Kind.toDisplayString(),
770 " but found a ",
771 kind().toDisplayString());
772 return static_cast<T*>(this);
773 }
774
775 bool matches(const FunctionSchema& schema) const;
776
777 // XXX: this function is meant to be used with string literals only!
778 bool matches(
779 const char* signature_literal,
780 at::ArrayRef<Symbol> const_inputs = {}) const;
781
782 bool isMemberOf(const OperatorSet& os) const;
783 template <typename T>
isMemberOfNode784 bool isMemberOf(const OperatorMap<T>& om) const {
785 auto it = om.map.find(kind());
786 if (it == om.map.end()) {
787 return false;
788 }
789 for (auto& op : it->second) {
790 if (matches(op.first->schema())) {
791 return true;
792 }
793 }
794 return false;
795 }
796
797 const FunctionSchema& schema() const;
798 const FunctionSchema* maybeSchema() const;
799 const Operator& getOperator() const;
800 Operation getOperation() const;
801
802 const Operator* maybeOperator() const;
803
804 void dump() const;
805
806 std::ostream& print(
807 std::ostream& out,
808 size_t level,
809 std::vector<const Node*>* groups,
810 bool print_source_locations = true,
811 bool print_attributes = true,
812 bool print_scopes = true,
813 bool print_body = true) const;
814
~NodeNode815 virtual ~Node() {
816 if (wrap_) {
817 wrap_->clear();
818 }
819 }
820
821 // Methods for accessing attributes
copyAttributesNode822 Node* copyAttributes(const Node& rhs) {
823 values_.clear();
824 for (const AVPtr& i : rhs.values_) {
825 values_.push_back(i->clone());
826 }
827 return this;
828 }
hasAttributeNode829 bool hasAttribute(Symbol name) const {
830 AT_ASSERT(name.is_attr());
831 return findAttr(name, false) != values_.end();
832 }
hasAttributeSNode833 bool hasAttributeS(const std::string& name) const {
834 return hasAttribute(Symbol::attr(name));
835 }
kindOfNode836 AttributeKind kindOf(Symbol name) const {
837 AT_ASSERT(name.is_attr());
838 return (*findAttr(name, true))->kind();
839 }
kindOfSNode840 AttributeKind kindOfS(const std::string& name) const {
841 return kindOf(Symbol::attr(name));
842 }
removeAttributeNode843 Node* removeAttribute(Symbol name) {
844 AT_ASSERT(name.is_attr());
845 values_.erase(findAttr(name, true));
846 return this;
847 }
removeAttributeSNode848 Node* removeAttributeS(const std::string& name) {
849 return removeAttribute(Symbol::attr(name));
850 }
hasAttributesNode851 bool hasAttributes() const {
852 return !values_.empty();
853 }
numAttributesNode854 size_t numAttributes() const {
855 return values_.size();
856 }
857 // The names are returned in order, since name actually is the index.
attributeNamesNode858 std::vector<Symbol> attributeNames() const {
859 std::vector<Symbol> names;
860 names.reserve(values_.size());
861 for (const AVPtr& a : values_) {
862 names.push_back(a->name);
863 }
864 return names;
865 }
attributeNamesSNode866 std::vector<const char*> attributeNamesS() const {
867 std::vector<const char*> names;
868 names.reserve(values_.size());
869 for (const AVPtr& a : values_) {
870 names.push_back(a->name.toUnqualString());
871 }
872 return names;
873 }
874
875 #define CREATE_ACCESSOR(Kind, method) \
876 Node* method##_(Symbol name, Kind##Attr::ConstructorType v) { \
877 return setAttr<Kind##Attr>( \
878 name, std::forward<Kind##Attr::ConstructorType>(v)); \
879 } \
880 const Kind##Attr::ValueType& method(Symbol name) const { \
881 return getAttr<Kind##Attr>(name); \
882 }
883
CREATE_ACCESSORNode884 CREATE_ACCESSOR(Float, f)
885 CREATE_ACCESSOR(Complex, c)
886 CREATE_ACCESSOR(Floats, fs)
887 CREATE_ACCESSOR(ComplexVals, cs)
888 CREATE_ACCESSOR(String, s)
889 CREATE_ACCESSOR(Strings, ss)
890 CREATE_ACCESSOR(Int, i)
891 CREATE_ACCESSOR(Ints, is)
892 CREATE_ACCESSOR(Graph, g)
893 CREATE_ACCESSOR(Graphs, gs)
894 CREATE_ACCESSOR(Type, ty)
895 CREATE_ACCESSOR(Types, tys)
896 CREATE_ACCESSOR(IValue, ival)
897
898 #undef CREATE_ACCESSOR
899
900 // Our Graphs are not very const-correct, so we need to allow returning
901 // non-const references too
902 GraphAttr::ValueType& g(Symbol name) {
903 return getAttr<GraphAttr>(name);
904 }
905
906 // does not use CREATE_ACCESSOR because we need additional asserts
t_Node907 Node* t_(Symbol name, TensorAttr::ConstructorType v) {
908 return setAttr<TensorAttr>(
909 name, std::forward<TensorAttr::ConstructorType>(v));
910 }
tNode911 const TensorAttr::ValueType& t(Symbol name) const {
912 return getAttr<TensorAttr>(name);
913 }
914
ts_Node915 Node* ts_(Symbol name, TensorsAttr::ConstructorType v) {
916 return setAttr<TensorsAttr>(
917 name, std::forward<TensorsAttr::ConstructorType>(v));
918 }
tsNode919 const TensorsAttr::ValueType& ts(Symbol name) const {
920 return getAttr<TensorsAttr>(name);
921 }
922
923 Block* findCommonAncestorBlockWith(Node* n);
924
925 size_t blocksFromGraphBlock();
926
927 private:
928 void printAttrValue(std::ostream& out, const Symbol& name) const;
929 void printAttributes(std::ostream& out, bool ignore_subgraph) const;
930
931 template <typename T>
setAttrNode932 Node* setAttr(Symbol name, typename T::ConstructorType v) {
933 AT_ASSERT(name.is_attr());
934 auto it = findAttr(name, false);
935 auto nv = AVPtr(new T(name, std::forward<typename T::ConstructorType>(v)));
936 // NOLINTNEXTLINE(bugprone-branch-clone)
937 if (it == values_.end()) {
938 values_.push_back(std::move(nv));
939 } else {
940 *it = std::move(nv);
941 }
942 return this;
943 }
944 template <typename T>
getAttrNode945 typename T::ValueType& getAttr(Symbol name) const {
946 AT_ASSERT(name.is_attr());
947 auto it = findAttr(name, true);
948 auto* child = dynamic_cast<T*>(it->get());
949 if (child == nullptr) {
950 throw IRAttributeError(name, true);
951 }
952 return child->value();
953 }
954 using AVPtr = AttributeValue::Ptr;
955 // NB: For determinism, we use a vector rather than a hash map. This does
956 // mean that lookups are O(n), so you shouldn't use Attributes to store
957 // a big pile of messages.
958 std::vector<AVPtr> values_;
findAttrNode959 std::vector<AVPtr>::iterator findAttr(Symbol name, bool required) {
960 AT_ASSERT(name.is_attr());
961 auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) {
962 return v->name == name;
963 });
964 if (required && it == values_.end()) {
965 throw IRAttributeError(name, false);
966 }
967 AT_ASSERT(!required || it != values_.end());
968 return it;
969 }
findAttrNode970 std::vector<AVPtr>::const_iterator findAttr(Symbol name, bool required)
971 const {
972 AT_ASSERT(name.is_attr());
973 auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) {
974 return v->name == name;
975 });
976 if (required && it == values_.end()) {
977 throw IRAttributeError(name, false);
978 }
979 AT_ASSERT(!required || it != values_.end());
980 return it;
981 }
982
983 enum class MoveSide { BEFORE, AFTER };
984 bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
985
986 std::pair<Value*, const Argument&> findInput(Symbol name);
987 // Lookup iterator in use list of _input i_ that corresponds to its use of
988 // _this_
989 use_list::iterator findUseForInput(size_t i);
990
991 // remove the use of input i, this sets input i to nullptr, but
992 // is only used internally to Node before setting it to a new value
993 // or erasing the entry from the list.
994 Value* dropInput(size_t i);
995
inBlockListNode996 bool inBlockList() const {
997 if (next() == nullptr) {
998 AT_ASSERT(prev() == nullptr);
999 }
1000 return next() != nullptr;
1001 }
1002
  // Unlink this node from its block's node list (defined out of line).
  void removeFromList();
  // Per-node consistency checks; presumably mirrors Graph::lint — see ir.cpp.
  void lint() const;

  // Assign/refresh this node's topological position index.
  void assignTopoPosition();
1007
 protected:
  // subclasses must override
  // this function is used by createClone to initialize a new version
  // of a node in another graph. It should allocate a new instance of the same
  // concrete type as 'this', but in graph 'g' which might be different
  // than graph_
  virtual Node* allocNewInstance(Graph* g) {
    // Base implementation: a plain Node carries no subclass state, so a
    // fresh Node of the same kind suffices.
    return new Node(g, kind());
  }
  // create a copy of all properties of Node s into this.
  // subclasses should extend if they have additional information to copy.
  // 'this' will be allocated with s->allocNewInstance(g) so it should have
  // the same concrete type as 's'
  virtual void cloneFrom(Node* s);
1022 };
1023
// A list of Nodes delimited by two sentinel nodes: input_ (the "param" node,
// whose outputs are the block's inputs) and output_ (the "return" node, whose
// inputs are the block's outputs). The sentinels double as the ends of the
// circular node list, so the list is never empty.
struct Block {
  friend struct Node;
  friend struct Graph;

  AT_DISALLOW_COPY_AND_ASSIGN(Block);
  TORCH_API Block(Graph* graph_, Node* node_);

  // Block inputs are modeled as the outputs of the param sentinel node.
  at::ArrayRef<Value*> inputs() {
    return input_->outputs();
  }
  at::ArrayRef<const Value*> inputs() const {
    const auto& inputs = input_->outputs();
    return {inputs.data(), inputs.size()};
  }
  // Block outputs are modeled as the inputs of the return sentinel node.
  at::ArrayRef<Value*> outputs() {
    return output_->inputs();
  }
  at::ArrayRef<const Value*> outputs() const {
    return static_cast<const Node*>(output_)->inputs();
  }
  // Iterate the nodes between the two sentinels.
  graph_node_list nodes() {
    return {input_, kNextDirection};
  }
  const_graph_node_list nodes() const {
    return {input_, kNextDirection};
  }
  // The output sentinel: its inputs are this block's outputs.
  Node* return_node() {
    return output_;
  }
  const Node* return_node() const {
    return output_;
  }
  // The input sentinel: its outputs are this block's inputs.
  Node* param_node() {
    return input_;
  }
  const Node* param_node() const {
    return input_;
  }
  /**
   * @warning NEVER pass raw pointer of smart pointer managed Graph to Python.
   * Check #87343 for details.
   */
  Graph* owningGraph() {
    return graph_;
  }
  const Graph* owningGraph() const {
    return graph_;
  }
  // The node this block is nested under, or nullptr for a root block.
  Node* owningNode() {
    return owning_node_;
  }
  const Node* owningNode() const {
    return owning_node_;
  }

  // Append a new block input with the given debug name; returns the Value.
  Value* addInput(const std::string& name = "") {
    Value* v = input_->addOutput();
    v->setDebugName(name);
    return v;
  }
  Value* insertInput(size_t i, const std::string& name = "") {
    Value* v = input_->insertOutput(i);
    v->setDebugName(name);
    return v;
  }
  void eraseInput(size_t i) {
    input_->eraseOutput(i);
  }
  void removeAllInputs() {
    input_->removeAllOutputs();
  }
  // Register `v` as the next block output; returns its output index.
  size_t registerOutput(Value* v) {
    output_->addInput(v);
    return outputs().size() - 1;
  }
  size_t insertOutput(size_t i, Value* n) {
    output_->insertInput(i, n);
    return i;
  }
  void eraseOutput(size_t i) {
    output_->removeInput(i);
  }
  void removeAllOutputs() {
    output_->removeAllInputs();
  }

  void replaceOutput(size_t i, Value* n) {
    output_->replaceInput(i, n);
  }
  void permuteOutputs(const std::vector<size_t>& new_inputs) {
    output_->permuteInputs(new_inputs);
  }
  void permuteInputs(const std::vector<size_t>& new_inputs) {
    input_->permuteOutputs(new_inputs);
  }

  // `n` must belong to the same graph and must not already be linked
  // into a block list (both checked below).
  Node* appendNode(Node* n) {
    AT_ASSERT(n->graph_ == graph_ && !n->inBlockList());
    n->insertBefore(output_);
    return n;
  }
  Node* prependNode(Node* n) {
    AT_ASSERT(n->graph_ == graph_ && !n->inBlockList());
    n->insertAfter(input_);
    return n;
  }

  // clone all inputs, nodes, and outputs from src and append them
  // to the inputs, nodes, and outputs of this block
  // value_map is used whenever a node in src references a free variable
  // in src to look up its corresponding value
  TORCH_API void cloneFrom(Block* src, std::function<Value*(Value*)> value_map);
  TORCH_API void remapTypes(const std::function<TypePtr(TypePtr)>& type_map);

  // Lazily created wrapper handed to Python bindings; it is cleared
  // (invalidated) when this block is destroyed — see the destructor.
  TORCH_API std::shared_ptr<Wrap<Block>> wrap() {
    if (!wrap_) {
      wrap_ = std::make_shared<Wrap<Block>>(this);
    }
    return wrap_;
  }

  virtual ~Block() {
    if (wrap_) {
      wrap_->clear();
    }
  }

  // Drop all outputs, destroy nodes in reverse order (so uses internal to
  // this block vanish before their definitions), then drop all inputs.
  void clear() {
    removeAllOutputs();
    for (auto it = nodes().rbegin(); it != nodes().rend(); it++) {
      it.destroyCurrent();
    }
    removeAllInputs();
  }

 private:
  void reIndexTopology();

  // get rid of all nodes
  // destroys in reverse order so that uses internal to this block
  // do not have to be removed before you can destroy the block
  void destroy();

  Graph* const graph_;
  // holds outputs in a way that can be reflected
  // as a Use object
  // also used as the beginning/end of the circular node list to avoid
  // having corner cases where the list is empty.
  Node* const output_;
  Node* const input_;
  Node* const
      owning_node_; // either the node that has this block or nullptr for root
  // a managing wrapper for Python to allow invalidation
  std::shared_ptr<Wrap<Block>> wrap_;
};
1179
// Owns all Nodes, Values, and Blocks of one JIT IR graph (tracked in the
// all_* sets below) and forwards its inputs/outputs/nodes accessors to the
// top-level Block. The create*/insert* helpers allocate nodes owned by this
// graph; insert* additionally place them at the current insertion point.
struct Graph : std::enable_shared_from_this<Graph> {
  AT_DISALLOW_COPY_AND_ASSIGN(Graph);
  friend struct Node;
  friend struct Value;
  friend struct Block;

 private:
  // only used to keep track of allocated nodes
  // actual representation of Graph is done with
  // inputs, outputs, nodes

  std::unordered_set<const Node*> all_nodes;
  std::unordered_set<const Value*> all_values;
  std::unordered_set<const Block*> all_blocks;
  // Counter handing out Value::unique_ ids (see Value's constructor).
  size_t next_unique_{0};

  std::unordered_map<std::string, Value*> unique_names_;
  // name_base_suffix tracks largest suffix currently used by all names sharing
  // same name_base. Key of this map is name_base, value is largest suffix
  // numeric value.
  std::unordered_map<std::string, size_t> name_base_suffix_;

  ScopePtr current_scope_;

  Block* const block_;
  // when insertNode() is called, the node is inserted before this node
  // by default this is set to append to the top level block
  Node* insert_before_;
  // Reset whenever the insert point moves. NOTE(review): its consumer is not
  // visible in this header — confirm usage in ir.cpp before relying on it.
  int64_t predicted_insert_count_ = 0;

  std::optional<size_t> op_version_;

 public:
  Graph(ScopePtr scope_root = c10::make_intrusive<Scope>())
      : current_scope_(std::move(scope_root)),
        block_(new Block(this, nullptr)),
        insert_before_(return_node()) {}

  // Inputs/outputs/nodes of the graph are those of its top-level block.
  at::ArrayRef<Value*> inputs() {
    return block_->inputs();
  }
  at::ArrayRef<const Value*> inputs() const {
    const Block& block = *block_;
    return block.inputs();
  }
  at::ArrayRef<Value*> outputs() {
    return block_->outputs();
  }
  at::ArrayRef<const Value*> outputs() const {
    const Block& block = *block_;
    return block.outputs();
  }
  graph_node_list nodes() {
    return block_->nodes();
  }
  const_graph_node_list nodes() const {
    const Block& block = *block_;
    return block.nodes();
  }
  Node* param_node() {
    return block_->param_node();
  }
  const Node* param_node() const {
    return block_->param_node();
  }
  Node* return_node() {
    return block_->return_node();
  }
  const Node* return_node() const {
    return block_->return_node();
  }
  // Map from debug name to the Value carrying it.
  const std::unordered_map<std::string, Value*>& debugNames() const {
    return unique_names_;
  }

  TORCH_API void push_scope(const std::string& scope_name);
  TORCH_API void pop_scope();

  ScopePtr current_scope() {
    return current_scope_;
  }

  // Operator-versioning tag carried by the graph (e.g. for upgraders).
  void set_op_version(std::optional<size_t> version) {
    op_version_ = version;
  }

  std::optional<size_t> get_op_version() {
    return op_version_;
  }

  void set_current_scope(ScopePtr scope) {
    current_scope_ = std::move(scope);
  }

  // Input/output management forwards to the top-level block.
  Value* addInput(const std::string& name = "") {
    return block_->addInput(name);
  }
  Value* insertInput(size_t i, const std::string& name = "") {
    return block_->insertInput(i, name);
  }
  void eraseInput(size_t i) {
    block_->eraseInput(i);
  }
  size_t registerOutput(Value* n) {
    return block_->registerOutput(n);
  }
  void eraseOutput(size_t i) {
    block_->eraseOutput(i);
  }

  // Node factories: allocate a node owned by this graph but do NOT insert it
  // anywhere; pair with insertNode()/appendNode() to place it.
  TORCH_API Node* create(NodeKind kind, size_t num_outputs = 1);
  TORCH_API Node* create(
      NodeKind kind,
      ArrayRef<Value*> inputs,
      size_t num_outputs = 1);

  TORCH_API Node* createNone();
  TORCH_API Node* createAutogradZero();
  TORCH_API Node* createUninitialized(TypePtr typ);
  TORCH_API Node* createWithSubgraph(Symbol kind);
  TORCH_API Node* createDifferentiableSubgraph();
  TORCH_API Node* createTuple(
      at::ArrayRef<Value*> values,
      TupleTypePtr optional_named_tuple = nullptr);
  TORCH_API Node* createTupleUnpack(Value* v);
  TORCH_API Node* createTupleIndex(
      Value* tup,
      Value* idx,
      const TypePtr& output_type);
  TORCH_API Node* createTupleSlice(
      Value* tup,
      int64_t beg,
      int64_t step_size,
      int64_t num_values);
  TORCH_API Node* createEnumName(Value* e);
  TORCH_API Node* createEnumValue(Value* e);
  TORCH_API Node* createList(
      const TypePtr& contained_type,
      at::ArrayRef<Value*> values);
  TORCH_API Node* createListUnpack(Value* v, size_t size);
  TORCH_API Node* createDict(
      const TypePtr& key_type,
      const TypePtr& value_type,
      at::ArrayRef<Value*> keys,
      at::ArrayRef<Value*> values);
  TORCH_API Node* createNumToTensor(Value* value);
  TORCH_API Node* createObject(const ClassTypePtr& type);
  TORCH_API Node* createSetAttr(
      Value* obj,
      const std::string& field,
      Value* newValue);
  TORCH_API Node* createGetAttr(Value* obj, const std::string& field);
  // Convenience: create a GetAttr node and insert it at the insertion point.
  Value* insertGetAttr(Value* obj, const std::string& field) {
    return insertNode(createGetAttr(obj, field))->output();
  }
  TORCH_API Node* createStore(const std::string& name, Value* v);
  TORCH_API Node* createLoad(const std::string& name, const TypePtr& type);
  TORCH_API Node* createIsInstance(Value* v, at::ArrayRef<TypePtr> types);

  TORCH_API Value* insertUncheckedCast(Value* v, TypePtr type);

  // Insert a ToList operator with argument \p v and output type \p type.
  // \returns the output of the operation.
  TORCH_API Value* insertToList(Value* v, TypePtr type);

  TORCH_API Value* insertFunctionCall(
      Function* callee,
      const MatchedSchema& matched);
  TORCH_API Value* insertMethodCall(
      std::string method_name,
      const MatchedSchema& matched);

  // Note: defined in python_ir.cpp and can be used only in python extension
  Node* createPythonOp(
      THPObjectPtr&& pyobj,
      const std::string& cconv,
      pyobj_list&& scalar_args);
  // clone n, making a new node in _this_ graph.
  // use value_map to translate inputs of n to inputs of the cloned node
  // if copy_blocks is false, it will not recursively clone the nested blocks
  // this node contains.
  TORCH_API Node* createClone(
      Node* n,
      const std::function<Value*(Value*)>& value_map,
      bool copy_blocks = true);

  // Insert constant IValue into the graph.
  TORCH_API Value* insertConstant(
      const IValue& val,
      std::optional<SourceRange> loc = std::nullopt,
      std::optional<ScopePtr> scope = std::nullopt);

  // Schema-driven insert:
  // This inserts a node into the graph with inputs determined from args and
  // kwargs using Python argument matching rules, and checks that the op matches
  // a known schema.
  //
  // If this node successfully completes, it guarantees the node
  // is a correctly-formed invocation of opname
  TORCH_API Value* insert(
      Symbol opname,
      at::ArrayRef<NamedValue> args,
      at::ArrayRef<NamedValue> kwargs = {},
      const std::optional<SourceRange>& range = {});

  Node* appendNode(Node* n) {
    return block_->appendNode(n);
  }

  Node* prependNode(Node* n) {
    return block_->prependNode(n);
  }

  // insert before insert_before_ node
  // initialized to insert at the end of the top level block
  // can be changed with setInsertPoint()
  Node* insertNode(Node* n) {
    AT_ASSERT(
        insert_before_->inBlockList() &&
        "insert point node is no longer in a block list");
    return n->insertBefore(insert_before_);
  }
  // set where nodes are inserted to append to the end of this block
  void setInsertPoint(Block* b) {
    AT_ASSERT(b->owningGraph() == this);
    setInsertPoint(b->return_node());
  }
  // set where nodes are inserted to insert _before_ this node
  // for implementation simplicity we only support inserting before a node for
  // now
  void setInsertPoint(Node* n) {
    AT_ASSERT(n->owningGraph() == this && n->inBlockList());
    insert_before_ = n;
    predicted_insert_count_ = 0;
  }
  Node* insertPoint() {
    return insert_before_;
  }

  // the top level block
  Block* block() {
    return block_;
  }
  const Block* block() const {
    return block_;
  }

  // Checks well-formedness and invariants of graph
  TORCH_API void lint() const;
  // for use in debugger
  TORCH_API void dump() const;

  TORCH_API ~Graph();

  TORCH_API std::string toString(bool print_source_locations = true) const;

  TORCH_API std::ostream& print(
      std::ostream& out,
      bool print_source_locations = true) const;

  friend TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g);

  TORCH_API std::shared_ptr<Graph> copy();
  TORCH_API std::unique_ptr<Graph> copyUnique();
  TORCH_API void remapTypes(const std::function<TypePtr(TypePtr)>& type_map);

 private:
  friend TORCH_API void Lint(const AliasDb* db);
  // Deallocation hooks used by Node/Value/Block (friends) to unregister
  // themselves from the all_* bookkeeping sets.
  TORCH_API void freeNode(Node* n);
  TORCH_API void freeValue(Value* v);
  TORCH_API void freeBlock(Block* b);
  void cloneFrom(Graph& src);
};
1453
1454 /** \brief An utility class for setting temporary insertion points.
1455 *
1456 * When an object of this class is created, it stores the current insertion
1457 * point, sets the new one, and restores the original insertion point when the
1458 * object is destroyed.
1459 */
struct WithInsertPoint {
  // Save the graph's current insertion point, then move it to `n`
  // (subsequent insertions go before `n`).
  WithInsertPoint(Node* n) : prev_(n->owningGraph()->insertPoint()) {
    n->owningGraph()->setInsertPoint(n);
  }
  // Convenience overload: insert at the end of block `b`.
  WithInsertPoint(Block* b) : WithInsertPoint(b->return_node()) {}

  // Restore the insertion point saved at construction.
  ~WithInsertPoint() {
    prev_->owningGraph()->setInsertPoint(prev_);
  }

 private:
  // The insertion point in effect before this guard was created.
  Node* prev_;
};
1473
1474 /** \brief An utility class for setting temporary scopes.
1475 *
1476 * When an object of this class is created, it stores the current scope, sets
1477 * the new one, and restores the original scope when the object is destroyed.
1478 */
struct WithCurrentScope {
  // Save the graph's current scope, then switch it to `scope`.
  WithCurrentScope(Graph& g, ScopePtr scope)
      : graph_(&g), prev_scope_(g.current_scope()) {
    g.set_current_scope(std::move(scope));
  }
  // Restore the scope saved at construction.
  ~WithCurrentScope() {
    graph_->set_current_scope(prev_scope_);
  }

 private:
  Graph* graph_;
  // The scope in effect before this guard was created.
  ScopePtr prev_scope_;
};
1492
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
inline Value::Value(Node* node_, size_t offset_)
    : node_(node_),
      offset_(offset_),
      // graph-wide unique id, drawn from the owning graph's counter
      unique_(node_->graph_->next_unique_++),
      // values start out typed as Tensor until refined via setType()
      type_(TensorType::get()) {
  // register with the graph's bookkeeping set of all live values
  node_->graph_->all_values.emplace(this);
}
1501
setType(TypePtr type)1502 inline Value* Value::setType(TypePtr type) {
1503 AT_ASSERT(type);
1504 if (auto dyn = type->castRaw<c10::DynamicType>()) {
1505 type = dyn->fallback();
1506 }
1507 type_ = std::move(type);
1508 for (Use& use : uses_) {
1509 use.user->op_ = nullptr;
1510 }
1511 return this;
1512 }
1513
// A Value's owning graph is the graph of the node it belongs to.
inline Graph* Value::owningGraph() {
  return node()->owningGraph();
}

inline const Graph* Value::owningGraph() const {
  return node()->owningGraph();
}
1521
1522 /************* All nodes not required to be defined before Graph **************/
// A node of kind prim::profile carrying a user-supplied callback over a
// vector of IValues. NOTE(review): the callback is presumably invoked by the
// executor with the runtime values flowing through this node — confirm in the
// interpreter sources.
struct ProfileOp : public Node {
  static const Symbol Kind;
  ProfileOp(Graph* graph, std::function<void(std::vector<IValue>&)> callback)
      : Node(graph, ::c10::prim::profile), callback_(std::move(callback)) {}

  void cloneFrom(Node* other_) override;
  Node* allocNewInstance(Graph* g) override;

  const std::function<void(std::vector<IValue>&)>& getCallback() const {
    return callback_;
  }

  void setCallback(std::function<void(std::vector<IValue>&)> callback) {
    callback_ = std::move(callback);
  }

  // Flag recording whether a tensor has been observed; starts out false.
  bool hasSeenTensor() const {
    return has_seen_tensor_;
  }

  void setHasSeenTensor(bool has_seen_tensor) {
    has_seen_tensor_ = has_seen_tensor;
  }

 private:
  std::function<void(std::vector<IValue>&)> callback_;
  bool has_seen_tensor_ = false;
};
1551
// Like ProfileOp, but of kind prim::profile_ivalue and without the
// seen-tensor flag; carries a callback over a vector of IValues.
struct TORCH_API ProfileIValueOp : public Node {
  static const Symbol Kind;
  ProfileIValueOp(
      Graph* graph,
      std::function<void(std::vector<IValue>&)> callback)
      : Node(graph, ::c10::prim::profile_ivalue),
        callback_(std::move(callback)) {}

  void cloneFrom(Node* other_) override;
  Node* allocNewInstance(Graph* g) override;

  // The profiling callback attached to this node.
  const std::function<void(std::vector<IValue>&)>& getCallback() const {
    return callback_;
  }

  void setCallback(std::function<void(std::vector<IValue>&)> callback) {
    callback_ = std::move(callback);
  }

 private:
  std::function<void(std::vector<IValue>&)> callback_;
};
1574
// execute a Python function, used for Ops we can't optimize but that we want to
// optimize around
//
// Note: actual implementation (ConcretePythonOp) is defined in python_ir.cpp
// which is not included in libtorch.so. We still include some bits and pieces
// of PythonOp here to enable writing simple passes generically. In general,
// python-aware bits need to be moved to the descendant classes.
struct TORCH_API PythonOp : public Node {
  using Node::Node;

  // Human-readable name of the wrapped Python callable.
  virtual std::string name() const = 0;
  // Write the op's scalar arguments to `out` (format defined by subclass).
  virtual void writeScalars(std::ostream& out) const = 0;
  void cloneFrom(Node* other_) override = 0;
  Node* allocNewInstance(Graph* g) override = 0;
  // recover the autograd.Function instance, if this PythonOp's function
  // was originally SomeFunction.apply
  // used in ONNX for discovering symbolics
  virtual std::optional<THPObjectPtr> autogradFunction() const = 0;

  // Python-specific consistency checks, analogous to Node::lint.
  virtual void lint_python() const = 0;
};
1596
// Free-function entry point for graph validation.
TORCH_API void LintGraph(const std::shared_ptr<Graph>& graph);

// Unpack tuple value `v`; returns the element values
// (see also Graph::createTupleUnpack).
TORCH_API at::ArrayRef<Value*> createTupleUnpack(Value* v);

/** Insert graph \p CALLEE into graph \p G using \p INPUTS as input values.
 * The insertion happens at the current insertion point.
 * Optionally, one can also pass \p VALUE_MAP to get a map between \p CALLEE
 * values and their cloned copies in \p G.
 */
TORCH_API std::vector<Value*> insertGraph(
    Graph& g,
    Graph& callee,
    ArrayRef<Value*> inputs);
TORCH_API std::vector<Value*> insertGraph(
    Graph& g,
    Graph& callee,
    ArrayRef<Value*> inputs,
    std::unordered_map<Value*, Value*>& value_map);

/** Insert function \p CALLEE after node \p TO_REPLACE, remove the node and
 * replace all its uses with corresponding outputs of the inserted function.
 * This asserts that the number of outputs of the original node and the
 * graph are the same.
 */
TORCH_API std::vector<Value*> inlineCallTo(
    Node* to_replace,
    GraphFunction* callee,
    bool use_graph = true);

TORCH_API std::vector<Value*> inlineCallTo(
    Node* to_replace,
    GraphFunction* callee,
    Graph* callee_graph);

/** If there is only one value in \p OUTPUTS and its kind is Tuple, insert a
 * tuple unpack node and return the resulting values.
 */
TORCH_API std::vector<Value*> unpackOutputs(const std::vector<Value*>& outputs);

// Collect all nodes of the given kind, optionally recursing into subblocks.
TORCH_API std::vector<Node*> findAllNodes(Graph& g, Symbol kind, bool recurse);
TORCH_API std::vector<Node*> findAllNodes(Block& b, Symbol kind, bool recurse);
TORCH_API std::vector<Node*> findAllNodes(
    at::ArrayRef<Block*> a,
    Symbol kind,
    bool recurse);
1642
// A collection of Operators constructed from schema signature literals,
// bucketed by operator name for lookup by Node.
struct TORCH_API OperatorSet {
  OperatorSet(std::initializer_list<const char*> sig_literals);
  // All operators in the set, flattened across names.
  std::vector<std::shared_ptr<Operator>> getOps() const;
  // Add more operators by signature literal.
  void insert(std::initializer_list<const char*> sig_literals);

 private:
  friend struct Node;
  std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>> ops;
};
1652
1653 template <typename T>
1654 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
1655 struct OperatorMap {
1656 // Type aliasing
1657 using OpMapType = typename std::pair<std::shared_ptr<Operator>, T>;
1658 using ValueType = std::vector<OpMapType>;
1659 using MapType = std::unordered_map<Symbol, ValueType>;
1660
1661 OperatorMap() = default;
1662 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
OperatorMapOperatorMap1663 explicit OperatorMap(
1664 std::initializer_list<std::pair<std::shared_ptr<Operator>, T>> init) {
1665 insert(init);
1666 }
1667 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
OperatorMapOperatorMap1668 explicit OperatorMap(std::initializer_list<std::pair<const char*, T>> init) {
1669 insert(init);
1670 }
1671
insertOperatorMap1672 void insert(const std::shared_ptr<Operator>& op, T val) {
1673 // Remove if exists before insert
1674 erase(op);
1675 map[Symbol::fromQualString(op->schema().name())].emplace_back(
1676 std::make_pair(op, val));
1677 }
1678
insertOperatorMap1679 void insert(const OperatorSet& op_set, T val) {
1680 for (auto& op : op_set.getOps()) {
1681 insert(op, val);
1682 }
1683 }
1684
insertOperatorMap1685 void insert(
1686 std::initializer_list<std::pair<std::shared_ptr<Operator>, T>> v) {
1687 for (auto& el : v) {
1688 insert(el.first, el.second);
1689 }
1690 }
1691
insertOperatorMap1692 void insert(std::initializer_list<std::pair<const char*, T>> v) {
1693 for (auto& el : v) {
1694 insert(getOperatorForLiteral(el.first), el.second);
1695 }
1696 }
1697
eraseOperatorMap1698 void erase(const std::shared_ptr<Operator>& op) {
1699 auto it = map.find(Symbol::fromQualString(op->schema().name()));
1700 if (it == map.end()) {
1701 return;
1702 }
1703 for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) {
1704 if (vit->first->schema() == op->schema()) {
1705 it->second.erase(vit);
1706 break;
1707 }
1708 }
1709 if (it->second.size() == 0) {
1710 map.erase(Symbol::fromQualString(op->schema().name()));
1711 }
1712 }
1713
containsOperatorMap1714 bool contains(const Operator& op) const {
1715 const auto it = map.find(Symbol::fromQualString(op.schema().name()));
1716 if (it == map.end()) {
1717 return false;
1718 }
1719 for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) {
1720 if (vit->first->schema() == op.schema()) {
1721 return true;
1722 }
1723 }
1724 return false;
1725 }
1726
containsOperatorMap1727 bool contains(const Node* n) const {
1728 return n->maybeOperator() && contains(n->getOperator());
1729 }
1730
findOperatorMap1731 std::optional<T> find(const Operator& op) {
1732 const auto it = map.find(Symbol::fromQualString(op.schema().name()));
1733 if (it == map.end()) {
1734 return std::nullopt;
1735 }
1736 for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) {
1737 if (vit->first->schema() == op.schema()) {
1738 return vit->second;
1739 }
1740 }
1741 return std::nullopt;
1742 }
1743
1744 // TODO: return iterator
getAllKeysAndValuesOperatorMap1745 std::vector<OpMapType> getAllKeysAndValues() const {
1746 std::vector<OpMapType> keys_values;
1747 keys_values.reserve(map.size());
1748 for (auto& symbol_mapping : map) {
1749 auto& vec = symbol_mapping.second;
1750 for (auto& pair : vec) {
1751 keys_values.push_back(pair);
1752 }
1753 }
1754 return keys_values;
1755 }
1756
1757 private:
1758 friend struct Node;
1759 MapType map;
1760 };
1761
1762 template <typename T>
1763 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
1764 struct FunctionSchemaMap {
1765 // Type aliasing
1766 using FuncSchemaMapType = typename std::pair<FunctionSchema, T>;
1767 using ValueType = std::vector<FuncSchemaMapType>;
1768 using MapType = std::unordered_map<Symbol, ValueType>;
1769
1770 FunctionSchemaMap() = default;
insertFunctionSchemaMap1771 void insert(const FunctionSchema& schema, T val) {
1772 // Remove if exists before insert
1773 erase(schema);
1774 map[Symbol::fromQualString(schema.name())].emplace_back(
1775 std::make_pair(schema, val));
1776 }
1777
eraseFunctionSchemaMap1778 void erase(const FunctionSchema& schema) {
1779 auto it = map.find(Symbol::fromQualString(schema.name()));
1780 if (it == map.end()) {
1781 return;
1782 }
1783 for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) {
1784 if (vit->first == schema) {
1785 it->second.erase(vit);
1786 break;
1787 }
1788 }
1789 if (it->second.size() == 0) {
1790 map.erase(Symbol::fromQualString(schema.name()));
1791 }
1792 }
1793
containsFunctionSchemaMap1794 bool contains(const FunctionSchema& schema) const {
1795 const auto it = map.find(Symbol::fromQualString(schema.name()));
1796 if (it == map.end()) {
1797 return false;
1798 }
1799 for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) {
1800 if (vit->first->schema() == schema) {
1801 return true;
1802 }
1803 }
1804 return false;
1805 }
1806
findFunctionSchemaMap1807 std::optional<T> find(const FunctionSchema& schema) const {
1808 const auto it = map.find(Symbol::fromQualString(schema.name()));
1809 if (it == map.end()) {
1810 return std::nullopt;
1811 }
1812 for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) {
1813 if (vit->first == schema) {
1814 return vit->second;
1815 }
1816 }
1817 return std::nullopt;
1818 }
1819
1820 // TODO: return iterator
getAllKeysAndValuesFunctionSchemaMap1821 std::vector<FuncSchemaMapType> getAllKeysAndValues() const {
1822 std::vector<FuncSchemaMapType> keys_values;
1823 keys_values.reserve(map.size());
1824 for (auto& symbol_mapping : map) {
1825 auto& vec = symbol_mapping.second;
1826 for (auto& pair : vec) {
1827 keys_values.push_back(pair);
1828 }
1829 }
1830 return keys_values;
1831 }
1832
1833 private:
1834 friend struct Node;
1835 MapType map;
1836 };
1837
1838 } // namespace torch::jit
1839