xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/ir/scope.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/jit_type.h>
3 #include <ATen/core/symbol.h>
4 #include <c10/util/intrusive_ptr.h>
5 #include <torch/csrc/Export.h>
6 #include <torch/csrc/jit/frontend/source_range.h>
7 #include <optional>
8 #include <unordered_map>
9 
10 namespace torch::jit {
11 struct ModuleInstanceInfo;
12 constexpr size_t kModuleInstanceInfo = 2;
13 
14 namespace utils {
15 std::string get_module_info(const ModuleInstanceInfo& module_instance_info);
16 } // namespace utils
17 
// Scope is a node of a trie that represents the tree of nested scopes.
// Individual scopes are pushed and popped from Graph, which holds a
// pointer to the current scope. Each Node in Graph holds a pointer
// to the scope that was current when the node was created.
// The trie never needs to shrink; it only grows until it is disposed
// of when Graph is deallocated. Hence, pointers to scopes held by nodes
// will always be valid as long as Graph is alive.
25 struct Scope;
26 using ScopePtr = c10::intrusive_ptr<Scope>;
27 using c10::Symbol;
28 
29 struct TORCH_API Scope : public c10::intrusive_ptr_target {
30  private:
31   ScopePtr parent_;
32   Symbol name_;
33   ScopePtr intrusive_from_this();
34 
35  public:
36   Scope();
37 
38   Scope(ScopePtr parent, Symbol name);
39 
40   ScopePtr push(Symbol name);
41 
42   ScopePtr parent();
43 
44   bool isRoot() const;
45 
46   bool isBlank() const;
47 
48   ScopePtr getRoot();
49 
50   size_t getDepth();
51 
52   Symbol name() const;
53 
54   std::string namesFromRoot(const std::string& separator = "/") const;
55 };
56 
57 struct Function;
58 struct InlinedCallStack;
59 
/**
 * ModuleInstanceInfo is a structure that bundles the module type and
 * instance name. It also provides public methods to get the pointer to the
 * module type and the instance name.
 *
 * This structure is mainly used as a private member in InlinedCallStack, such
 * that one can follow the callstack to find the relevant module hierarchy.
 */
68 struct ModuleInstanceInfo {
69  private:
70   c10::ClassTypePtr module_type_{nullptr};
71   std::string instance_name_;
72 
73  public:
74   ModuleInstanceInfo() = default;
75   ModuleInstanceInfo(c10::ClassTypePtr module_type, std::string instance_name);
class_typeModuleInstanceInfo76   c10::ClassTypePtr class_type() {
77     return module_type_;
78   }
class_typeModuleInstanceInfo79   c10::ClassTypePtr class_type() const {
80     return module_type_;
81   }
instance_nameModuleInstanceInfo82   std::string instance_name() const {
83     return instance_name_;
84   }
85 
86   bool operator==(const ModuleInstanceInfo& rhs) const {
87     return (class_type() == rhs.class_type()) &&
88         (instance_name() == rhs.instance_name());
89   }
90 };
91 
92 /**
93  * InlinedCallStack is an element in a list representing callstack of functions
94  * that have been inlined.
95  *
96  * Each such element holds info about the current callsite (Function and
97  * SourceRange) and a pointer to the next element in the list. The last element
98  * in the list represents the innermost function that was inlined.
99  *
100  * For instance, if a node has a callstack
101  *    [foo, source_range1] -> [bar, source_range2]
102  * it means that this node was originally from function 'bar' that was called
103  * at 'source_range2' in function 'foo' that was called in the current function
104  * at 'source_range1'.
105  *
106  * If a node did not come from any inlined function, its callstack will be
107  * empty.
108  *
109  * The callstack lists only grow, we never remove elements from them, which
110  * allows us to reuse same elements in different lists. For instance, if we
111  * inline function 'bar' to 'foo' and then inline 'foo' to two functions 'ham'
112  * and 'baz', the callstacks would look like:
113  *
114  *  [baz, source_range3]  --
115  *                           \
116  *                             --> [foo, source_range1] -> [bar, source_range2]
117  *                           /
118  *  [ham, source_range4]  --
119  */
120 using InlinedCallStackPtr = c10::intrusive_ptr<InlinedCallStack>;
121 using InlinedCallStackEntry =
122     std::tuple<Function*, SourceRange, std::optional<ModuleInstanceInfo>>;
123 
124 struct TORCH_API InlinedCallStack : public c10::intrusive_ptr_target {
125  private:
126   std::optional<InlinedCallStackPtr> callee_;
127   Function* fn_;
128   // Reason for fn_name_ even though we have fn_
129   // Serialized callstack is used in circustmances where InlinedCallstack
130   // cannot be constructed during runtime, e.g. mobile runtime or
131   // delegated backends.
132   // Since in those cases we do not have Function* we store function name
133   // fn_name does not give you access to the same information that Function*
134   // does, however in mobile/delegated backend runtime we use InlindedCallStack
135   // for exception stack and for that purpose fn_name_ suffices.
136   const std::string fn_name_;
137   SourceRange source_range_;
138   InlinedCallStackPtr intrusive_from_this();
139   std::optional<ModuleInstanceInfo> module_instance_info_;
140 
141  public:
142   // Constructor for a leaf callstack node.
143   InlinedCallStack(Function* fn, SourceRange source_range);
144 
145   // Constructor for a leaf callstack node.
146   InlinedCallStack(
147       Function* fn,
148       SourceRange source_range,
149       std::optional<ModuleInstanceInfo> module_instance_info);
150 
151   // Constructor for a leaf callstack node.
152   InlinedCallStack(
153       Function* fn,
154       SourceRange source_range,
155       std::optional<ModuleInstanceInfo> module_instance_info,
156       std::string& function_name);
157 
158   // Constructor for an inner callstack node.
159   InlinedCallStack(
160       InlinedCallStackPtr callee,
161       Function* fn,
162       SourceRange source_range);
163 
164   InlinedCallStack(
165       InlinedCallStackPtr callee,
166       Function* fn,
167       SourceRange source_range,
168       std::optional<ModuleInstanceInfo> module_instance_info);
169 
170   InlinedCallStack(
171       InlinedCallStackPtr callee,
172       Function* fn,
173       SourceRange source_range,
174       std::optional<ModuleInstanceInfo> module_instance_info,
175       std::string& function_name);
176 
177   // Return next element in the callstack list.
178   std::optional<InlinedCallStackPtr> callee() const;
179 
180   // Return module instance associated with the current element.
181   std::optional<ModuleInstanceInfo> module_instance() const;
182 
183   // Returns the source range of the node
184   SourceRange source_range() const;
185 
186   Function* function() const;
187 
188   const std::string& function_name() const;
189 
190   // Return callstack as a vector of [Function, SourceRange] pairs.
191   std::vector<InlinedCallStackEntry> vec();
192 
193   void setCallee(std::optional<InlinedCallStackPtr>);
194 
195   bool operator==(const InlinedCallStack& rhs) const {
196     // No need to compare fn_, since source_range equivalence check
197     // should suffice.
198     return (module_instance().has_value() ==
199             rhs.module_instance().has_value()) &&
200         (module_instance().has_value() &&
201          module_instance().value() == rhs.module_instance().value()) &&
202         callee() == rhs.callee() && source_range() == rhs.source_range();
203   }
204 
205   bool operator!=(const InlinedCallStack& rhs) const {
206     return !(*this == rhs);
207   }
208 };
209 
// {source range, node name, InlinedCallStack}
// We store the node name because the same debug info will be used for
// profiling as well, so we need to know op names too.
213 using DebugInfoTuple =
214     std::tuple<SourceRange, std::string, InlinedCallStackPtr>;
215 constexpr size_t kDebugInfoTupleSourceRangeIndex{0};
216 constexpr size_t kDebugInfoTupleNodeNameIndex{1};
217 constexpr size_t kDebugInfoTupleInlinedCSIndex{2};
218 } // namespace torch::jit
219