1 #include <torch/csrc/lazy/core/ir_dump_util.h>
2
3 #include <c10/util/irange.h>
4 #include <torch/csrc/lazy/backend/backend_interface.h>
5 #include <torch/csrc/lazy/backend/lowering_context.h>
6 #include <torch/csrc/lazy/core/ir_util.h>
7 #include <optional>
8
9 #include <regex>
10 #include <sstream>
11 #include <unordered_map>
12
13 namespace torch {
14 namespace lazy {
15 namespace {
16
17 using NodeIdMap = std::unordered_map<const Node*, size_t>;
18
// One `name=value` attribute parsed out of a node's ToString() text.
struct AttrTag {
  // Attribute name: the text to the left of '='.
  std::string name;
  // Raw attribute text to the right of '='.
  std::string value;
  // Offset in the parsed string at which the value ends.
  std::string::size_type pos{0};
};
24
// Returns `pos` advanced past a ", " separator when one starts exactly at
// `pos`; otherwise returns `pos` unchanged.
std::string::size_type SkipTagSeparator(
    const std::string& node_string,
    std::string::size_type pos) {
  constexpr char kSeparator[] = ", ";
  constexpr std::string::size_type kSeparatorLen = sizeof(kSeparator) - 1;
  if (node_string.compare(pos, kSeparatorLen, kSeparator) != 0) {
    return pos;
  }
  return pos + kSeparatorLen;
}
30
ParseAttrTag(const std::string & node_string,std::string::size_type pos)31 std::optional<AttrTag> ParseAttrTag(
32 const std::string& node_string,
33 std::string::size_type pos) {
34 // @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful
35 const std::regex tag_regex("^([a-zA-Z0-9_]+)=");
36 std::smatch match;
37 // @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful
38 if (!std::regex_search(
39 node_string.begin() + pos, node_string.end(), match, tag_regex)) {
40 return std::nullopt;
41 }
42
43 std::string::size_type vpos = match[1].second - node_string.begin() + 1;
44 char nested_open = -1;
45 char nested_close = -1;
46 size_t nest_count = 1;
47 AttrTag tag;
48 tag.name = match[1].str();
49 for (pos = vpos; pos < node_string.size(); ++pos) {
50 if (nested_open < 0) {
51 if (SkipTagSeparator(node_string, pos) != pos) {
52 break;
53 }
54 switch (node_string[pos]) {
55 case '(':
56 nested_open = node_string[pos];
57 nested_close = ')';
58 break;
59 case '[':
60 nested_open = node_string[pos];
61 nested_close = ']';
62 break;
63 case '{':
64 nested_open = node_string[pos];
65 nested_close = '}';
66 break;
67 }
68 } else if (node_string[pos] == nested_close) {
69 --nest_count;
70 if (nest_count == 0) {
71 nest_count = 1;
72 nested_open = nested_close = -1;
73 }
74 } else if (node_string[pos] == nested_open) {
75 ++nest_count;
76 }
77 }
78 tag.value = node_string.substr(vpos, pos - vpos);
79 tag.pos = pos;
80 return tag;
81 }
82
GenerateIdMap(c10::ArrayRef<const Node * > post_order)83 NodeIdMap GenerateIdMap(c10::ArrayRef<const Node*> post_order) {
84 NodeIdMap id_map;
85 for (auto node : post_order) {
86 TORCH_CHECK(id_map.emplace(node, id_map.size()).second, node->ToString());
87 }
88 return id_map;
89 }
90
GetRootsIds(c10::ArrayRef<const Node * > roots)91 std::unordered_map<const Node*, size_t> GetRootsIds(
92 c10::ArrayRef<const Node*> roots) {
93 std::unordered_map<const Node*, size_t> roots_ids;
94 for (const auto i : c10::irange(roots.size())) {
95 roots_ids[roots[i]] = i;
96 }
97 return roots_ids;
98 }
99
GetRootNodeId(const Node * node,const std::unordered_map<const Node *,size_t> & roots_ids)100 std::optional<size_t> GetRootNodeId(
101 const Node* node,
102 const std::unordered_map<const Node*, size_t>& roots_ids) {
103 auto it = roots_ids.find(node);
104 if (it == roots_ids.end()) {
105 return std::nullopt;
106 }
107 return it->second;
108 }
109
GetNodeTags(const Node * node)110 std::vector<AttrTag> GetNodeTags(const Node* node) {
111 std::string node_string = node->ToString();
112 std::string op_string = node->op().ToString();
113 std::string::size_type pos = node_string.find(op_string);
114 TORCH_CHECK(pos != std::string::npos, node_string, " : ", op_string);
115 pos += op_string.size();
116 std::vector<AttrTag> tags;
117 for (;;) {
118 pos = SkipTagSeparator(node_string, pos);
119 auto tag = ParseAttrTag(node_string, pos);
120 if (!tag) {
121 break;
122 }
123 pos = tag->pos;
124 tags.push_back(std::move(*tag));
125 }
126 return tags;
127 }
128
GenerateDotNodeLabel(const Node * node,const std::unordered_map<const Node *,size_t> & roots_ids)129 std::string GenerateDotNodeLabel(
130 const Node* node,
131 const std::unordered_map<const Node*, size_t>& roots_ids) {
132 static const size_t kMaxValueSize = 64;
133 std::stringstream ss;
134 ss << node->op() << "\\n" << node->shape();
135 for (auto& tag : GetNodeTags(node)) {
136 ss << "\\n" << tag.name << "=";
137 if (tag.value.size() < kMaxValueSize) {
138 ss << tag.value;
139 } else {
140 ss << tag.value.substr(0, kMaxValueSize) << "...";
141 }
142 }
143 auto opt_root_id = GetRootNodeId(node, roots_ids);
144 if (opt_root_id) {
145 ss << "\\nROOT=" << *opt_root_id;
146 }
147 return ss.str();
148 }
149
GenerateDotNodeSpec(const Node * node,const std::unordered_map<const Node *,size_t> & roots_ids)150 std::string GenerateDotNodeSpec(
151 const Node* node,
152 const std::unordered_map<const Node*, size_t>& roots_ids) {
153 std::stringstream ss;
154 ss << "label=\"" << GenerateDotNodeLabel(node, roots_ids) << "\"";
155 return ss.str();
156 }
157
GenerateTextNodeSpec(const Node * node,const NodeIdMap & id_map)158 std::string GenerateTextNodeSpec(const Node* node, const NodeIdMap& id_map) {
159 std::stringstream ss;
160 ss << node->shapes() << " " << node->op() << "(";
161 size_t count = 0;
162 for (auto& output : node->operands()) {
163 if (count > 0) {
164 ss << ", ";
165 }
166 ss << "%" << id_map.at(output.node);
167 if (output.node->num_outputs() > 1) {
168 ss << "." << output.index;
169 }
170 ++count;
171 }
172 ss << ")";
173 for (auto& tag : GetNodeTags(node)) {
174 ss << ", " << tag.name << "=" << tag.value;
175 }
176 return ss.str();
177 }
178
179 } // namespace
180
ToDot(c10::ArrayRef<const Node * > nodes)181 std::string DumpUtil::ToDot(c10::ArrayRef<const Node*> nodes) {
182 auto post_order = Util::ComputePostOrder(nodes);
183 return PostOrderToDot(post_order, nodes);
184 }
185
PostOrderToDot(c10::ArrayRef<const Node * > post_order,c10::ArrayRef<const Node * > roots)186 std::string DumpUtil::PostOrderToDot(
187 c10::ArrayRef<const Node*> post_order,
188 c10::ArrayRef<const Node*> roots) {
189 std::unordered_map<const Node*, size_t> roots_ids = GetRootsIds(roots);
190 NodeIdMap id_map = GenerateIdMap(post_order);
191 std::stringstream ss;
192 ss << "digraph G {\n";
193 for (auto node : post_order) {
194 ss << " node" << id_map.at(node) << " ["
195 << GenerateDotNodeSpec(node, roots_ids) << "]\n";
196 }
197 for (auto it = post_order.rbegin(); it != post_order.rend(); ++it) {
198 const Node* node = *it;
199 size_t id = id_map.at(node);
200 for (const auto i : c10::irange(node->operands().size())) {
201 const Output& output = node->operand(i);
202 ss << " node" << id_map.at(output.node) << " -> node" << id;
203 if (node->operands().size() > 1) {
204 ss << " [label=\"i=" << i;
205 if (output.node->num_outputs() > 1) {
206 ss << ",o=" << output.index;
207 }
208 ss << "\"]\n";
209 } else {
210 if (output.node->num_outputs() > 1) {
211 ss << " [label=\"o=" << output.index << "\"]";
212 }
213 ss << "\n";
214 }
215 }
216 }
217 ss << "}\n";
218 return ss.str();
219 }
220
ToText(c10::ArrayRef<const Node * > nodes)221 std::string DumpUtil::ToText(c10::ArrayRef<const Node*> nodes) {
222 auto post_order = Util::ComputePostOrder(nodes);
223 return PostOrderToText(post_order, nodes);
224 }
225
PostOrderToText(c10::ArrayRef<const Node * > post_order,c10::ArrayRef<const Node * > roots)226 std::string DumpUtil::PostOrderToText(
227 c10::ArrayRef<const Node*> post_order,
228 c10::ArrayRef<const Node*> roots) {
229 std::unordered_map<const Node*, size_t> roots_ids = GetRootsIds(roots);
230 NodeIdMap id_map = GenerateIdMap(post_order);
231 std::stringstream ss;
232 ss << "IR {\n";
233 for (auto node : post_order) {
234 auto opt_root_id = GetRootNodeId(node, roots_ids);
235 ss << " %" << id_map.at(node) << " = "
236 << GenerateTextNodeSpec(node, id_map);
237 if (opt_root_id) {
238 ss << ", ROOT=" << *opt_root_id;
239 }
240 ss << ", NodeType=" << typeid(*node).name();
241 ss << "\n";
242 }
243 ss << "}\n";
244 return ss.str();
245 }
246
ToBackend(c10::ArrayRef<Value> values,const BackendDevice & device)247 std::string DumpUtil::ToBackend(
248 c10::ArrayRef<Value> values,
249 const BackendDevice& device) {
250 auto lowering_ctx = LoweringContext::Create("IrToBackend", device);
251 for (auto& ir_value : values) {
252 lowering_ctx->AddResult(ir_value);
253 }
254 auto computation = lowering_ctx->Build();
255 return getBackend()->GetComputationBackendText(computation);
256 }
257
258 } // namespace lazy
259 } // namespace torch
260