1 #include <torch/csrc/jit/ir/scope.h>
2
3 #include <ATen/core/class_type.h>
4 #include <ATen/core/function.h>
5
6 namespace torch::jit {
7 // util functions
8 namespace utils {
9
get_module_info(const ModuleInstanceInfo & module_instance_info)10 std::string get_module_info(const ModuleInstanceInfo& module_instance_info) {
11 std::string module_info;
12 const auto& class_type = module_instance_info.class_type();
13 std::string instance_name = module_instance_info.instance_name();
14 std::string type_name;
15 if (class_type) {
16 type_name += class_type->name()->qualifiedName();
17 type_name = type_name.substr(type_name.find_last_of('.') + 1);
18 }
19 if (type_name.empty()) {
20 type_name = "UNKNOWN_TYPE";
21 }
22 if (instance_name.empty()) {
23 instance_name = "UNKNOWN_INSTANCE";
24 }
25 module_info.append(instance_name).append("(").append(type_name).append(")");
26 return module_info;
27 }
28
29 } // namespace utils
intrusive_from_this()30 ScopePtr Scope::intrusive_from_this() {
31 c10::raw::intrusive_ptr::incref(this); // we are creating a new pointer
32 // from a raw `this` pointer
33 // so we need to bump the refcount
34 // to account for this ownership
35 return c10::intrusive_ptr<Scope>::reclaim(this);
36 }
37
Scope()38 Scope::Scope() : name_(Symbol::scope("")) {}
39
Scope(ScopePtr parent,Symbol name)40 Scope::Scope(ScopePtr parent, Symbol name)
41 : parent_(std::move(parent)), name_(name) {}
42
push(Symbol name)43 ScopePtr Scope::push(Symbol name) {
44 return c10::make_intrusive<Scope>(intrusive_from_this(), name);
45 }
46
parent()47 ScopePtr Scope::parent() {
48 if (!parent_) {
49 throw std::runtime_error("Cannot get parent from Scope with no parent");
50 }
51 return parent_;
52 }
53
isRoot() const54 bool Scope::isRoot() const {
55 return !parent_;
56 }
57
isBlank() const58 bool Scope::isBlank() const {
59 static const Symbol blank = Symbol::scope("");
60 return isRoot() && name() == blank;
61 }
62
getRoot()63 ScopePtr Scope::getRoot() {
64 ScopePtr current = intrusive_from_this();
65 while (current->parent_) {
66 current = current->parent_;
67 }
68 return current;
69 }
70
getDepth()71 size_t Scope::getDepth() {
72 size_t d = 1;
73 ScopePtr current = intrusive_from_this();
74 while (current->parent_) {
75 current = current->parent_;
76 d += 1;
77 }
78 return d;
79 }
80
name() const81 Symbol Scope::name() const {
82 return name_;
83 }
84
namesFromRoot(const std::string & separator) const85 std::string Scope::namesFromRoot(const std::string& separator) const {
86 // TODO: I think the answer is we shouldn't have used Symbol here
87 std::string out = this->name_.toUnqualString();
88 if (this->isRoot()) {
89 return out;
90 }
91 ScopePtr parent = this->parent_;
92 while (!parent->isRoot()) {
93 // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
94 out = std::string(parent->name_.toUnqualString()) + separator + out;
95 parent = parent->parent_;
96 }
97 return out;
98 }
99
intrusive_from_this()100 InlinedCallStackPtr InlinedCallStack::intrusive_from_this() {
101 c10::raw::intrusive_ptr::incref(this); // we are creating a new pointer
102 // from a raw `this` pointer
103 // so we need to bump the refcount
104 // to account for this ownership
105 return c10::intrusive_ptr<InlinedCallStack>::reclaim(this);
106 }
107
InlinedCallStack(Function * fn,SourceRange source_range)108 InlinedCallStack::InlinedCallStack(Function* fn, SourceRange source_range)
109 : fn_(fn),
110 fn_name_(fn_ ? fn_->name() : ""),
111 source_range_(std::move(source_range)) {}
112
InlinedCallStack(Function * fn,SourceRange source_range,std::optional<ModuleInstanceInfo> module_instance_info)113 InlinedCallStack::InlinedCallStack(
114 Function* fn,
115 SourceRange source_range,
116 std::optional<ModuleInstanceInfo> module_instance_info)
117 : fn_(fn),
118 fn_name_(fn_ ? fn_->name() : ""),
119 source_range_(std::move(source_range)),
120 module_instance_info_(std::move(module_instance_info)) {}
121
InlinedCallStack(Function * fn,SourceRange source_range,std::optional<ModuleInstanceInfo> module_instance_info,std::string & function_name)122 InlinedCallStack::InlinedCallStack(
123 Function* fn,
124 SourceRange source_range,
125 std::optional<ModuleInstanceInfo> module_instance_info,
126 std::string& function_name)
127 : fn_(fn),
128 fn_name_(std::move(function_name)),
129 source_range_(std::move(source_range)),
130 module_instance_info_(std::move(module_instance_info)) {}
131
InlinedCallStack(InlinedCallStackPtr callee,Function * fn,SourceRange source_range)132 InlinedCallStack::InlinedCallStack(
133 InlinedCallStackPtr callee,
134 Function* fn,
135 SourceRange source_range)
136 : callee_(std::move(callee)),
137 fn_(fn),
138 fn_name_(fn_ ? fn_->name() : ""),
139 source_range_(std::move(source_range)) {}
140
InlinedCallStack(InlinedCallStackPtr callee,Function * fn,SourceRange source_range,std::optional<ModuleInstanceInfo> module_instance_info,std::string & function_name)141 InlinedCallStack::InlinedCallStack(
142 InlinedCallStackPtr callee,
143 Function* fn,
144 SourceRange source_range,
145 std::optional<ModuleInstanceInfo> module_instance_info,
146 std::string& function_name)
147 : callee_(std::move(callee)),
148 fn_(fn),
149 fn_name_(std::move(function_name)),
150 source_range_(std::move(source_range)),
151 module_instance_info_(std::move(module_instance_info)) {}
152
InlinedCallStack(InlinedCallStackPtr callee,Function * fn,SourceRange source_range,std::optional<ModuleInstanceInfo> module_instance_info)153 InlinedCallStack::InlinedCallStack(
154 InlinedCallStackPtr callee,
155 Function* fn,
156 SourceRange source_range,
157 std::optional<ModuleInstanceInfo> module_instance_info)
158 : callee_(std::move(callee)),
159 fn_(fn),
160 fn_name_(fn_ ? fn_->name() : ""),
161 source_range_(std::move(source_range)),
162 module_instance_info_(std::move(module_instance_info)) {}
163
callee() const164 std::optional<InlinedCallStackPtr> InlinedCallStack::callee() const {
165 return callee_;
166 }
167
setCallee(std::optional<InlinedCallStackPtr> callee)168 void InlinedCallStack::setCallee(std::optional<InlinedCallStackPtr> callee) {
169 callee_ = std::move(callee);
170 }
171
module_instance() const172 std::optional<ModuleInstanceInfo> InlinedCallStack::module_instance() const {
173 return module_instance_info_;
174 }
175
source_range() const176 SourceRange InlinedCallStack::source_range() const {
177 return source_range_;
178 }
179
function() const180 Function* InlinedCallStack::function() const {
181 return fn_;
182 }
183
function_name() const184 const std::string& InlinedCallStack::function_name() const {
185 return fn_name_;
186 }
187
vec()188 std::vector<InlinedCallStackEntry> InlinedCallStack::vec() {
189 std::vector<InlinedCallStackEntry> r;
190 std::optional<InlinedCallStackPtr> current = intrusive_from_this();
191 while (current) {
192 r.emplace_back(
193 (*current)->fn_,
194 (*current)->source_range_,
195 (*current)->module_instance_info_);
196 current = (*current)->callee_;
197 }
198 return r;
199 }
200
ModuleInstanceInfo(c10::ClassTypePtr module_type,std::string instance_name)201 ModuleInstanceInfo::ModuleInstanceInfo(
202 c10::ClassTypePtr module_type,
203 std::string instance_name)
204 : module_type_(std::move(module_type)),
205 instance_name_(std::move(instance_name)) {}
206 } // namespace torch::jit
207