xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/ir_dump_util.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/lazy/core/ir_dump_util.h>
2 
3 #include <c10/util/irange.h>
4 #include <torch/csrc/lazy/backend/backend_interface.h>
5 #include <torch/csrc/lazy/backend/lowering_context.h>
6 #include <torch/csrc/lazy/core/ir_util.h>
7 #include <optional>
8 
9 #include <regex>
10 #include <sstream>
11 #include <unordered_map>
12 
13 namespace torch {
14 namespace lazy {
15 namespace {
16 
17 using NodeIdMap = std::unordered_map<const Node*, size_t>;
18 
// One `name=value` attribute parsed out of a node's ToString() text.
struct AttrTag {
  // Attribute name: the identifier in front of '='.
  std::string name;
  // Raw attribute value text following '=', up to the next top-level ", ".
  std::string value;
  // Offset in the parsed string just past the end of `value`.
  std::string::size_type pos = 0;
};

// Returns `pos + 2` if a ", " tag separator starts at offset `pos` of
// `node_string`, otherwise returns `pos` unchanged.
std::string::size_type SkipTagSeparator(
    const std::string& node_string,
    std::string::size_type pos) {
  return node_string.compare(pos, 2, ", ") == 0 ? pos + 2 : pos;
}

// Parses one `name=value` tag that starts exactly at offset `pos` of
// `node_string`.
//
// The name must match [a-zA-Z0-9_]+ and be immediately followed by '='. The
// value runs until the first ", " that is not inside a bracketed group opened
// by '(', '[' or '{', or until the end of the string. Only the bracket kind
// that opened the outermost group is counted for nesting.
//
// Returns std::nullopt if no tag starts at `pos` (including `pos` past the end
// of the string). Otherwise returns the tag, with `tag.pos` set to the offset
// just past its value.
std::optional<AttrTag> ParseAttrTag(
    const std::string& node_string,
    std::string::size_type pos) {
  // `begin() + pos` below would be undefined behavior past the end.
  if (pos > node_string.size()) {
    return std::nullopt;
  }
  // Compiled once: building a std::regex is expensive, and this runs for every
  // attribute of every dumped node. Matching only reads the const regex.
  // @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful
  static const std::regex tag_regex("^([a-zA-Z0-9_]+)=");
  std::smatch match;
  // @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful
  if (!std::regex_search(
          node_string.begin() + pos, node_string.end(), match, tag_regex)) {
    return std::nullopt;
  }

  // The value starts right after the '=' that follows the name.
  const std::string::size_type vpos =
      match[1].second - node_string.begin() + 1;

  // Nesting depth of the bracketed group being skipped; 0 means we are at the
  // top level of the value. A counter is used instead of a `char` sentinel of
  // -1: plain `char` is unsigned on some platforms (e.g. ARM), where
  // `sentinel < 0` is never true and separators would never be recognized.
  size_t depth = 0;
  char open_bracket = '\0';
  char close_bracket = '\0';
  for (pos = vpos; pos < node_string.size(); ++pos) {
    const char c = node_string[pos];
    if (depth == 0) {
      // A top-level ", " terminates the value.
      if (SkipTagSeparator(node_string, pos) != pos) {
        break;
      }
      switch (c) {
        case '(':
          open_bracket = '(';
          close_bracket = ')';
          depth = 1;
          break;
        case '[':
          open_bracket = '[';
          close_bracket = ']';
          depth = 1;
          break;
        case '{':
          open_bracket = '{';
          close_bracket = '}';
          depth = 1;
          break;
        default:
          break;
      }
    } else if (c == close_bracket) {
      --depth;
    } else if (c == open_bracket) {
      ++depth;
    }
  }

  AttrTag tag;
  tag.name = match[1].str();
  tag.value = node_string.substr(vpos, pos - vpos);
  tag.pos = pos;
  return tag;
}
82 
GenerateIdMap(c10::ArrayRef<const Node * > post_order)83 NodeIdMap GenerateIdMap(c10::ArrayRef<const Node*> post_order) {
84   NodeIdMap id_map;
85   for (auto node : post_order) {
86     TORCH_CHECK(id_map.emplace(node, id_map.size()).second, node->ToString());
87   }
88   return id_map;
89 }
90 
GetRootsIds(c10::ArrayRef<const Node * > roots)91 std::unordered_map<const Node*, size_t> GetRootsIds(
92     c10::ArrayRef<const Node*> roots) {
93   std::unordered_map<const Node*, size_t> roots_ids;
94   for (const auto i : c10::irange(roots.size())) {
95     roots_ids[roots[i]] = i;
96   }
97   return roots_ids;
98 }
99 
GetRootNodeId(const Node * node,const std::unordered_map<const Node *,size_t> & roots_ids)100 std::optional<size_t> GetRootNodeId(
101     const Node* node,
102     const std::unordered_map<const Node*, size_t>& roots_ids) {
103   auto it = roots_ids.find(node);
104   if (it == roots_ids.end()) {
105     return std::nullopt;
106   }
107   return it->second;
108 }
109 
GetNodeTags(const Node * node)110 std::vector<AttrTag> GetNodeTags(const Node* node) {
111   std::string node_string = node->ToString();
112   std::string op_string = node->op().ToString();
113   std::string::size_type pos = node_string.find(op_string);
114   TORCH_CHECK(pos != std::string::npos, node_string, " : ", op_string);
115   pos += op_string.size();
116   std::vector<AttrTag> tags;
117   for (;;) {
118     pos = SkipTagSeparator(node_string, pos);
119     auto tag = ParseAttrTag(node_string, pos);
120     if (!tag) {
121       break;
122     }
123     pos = tag->pos;
124     tags.push_back(std::move(*tag));
125   }
126   return tags;
127 }
128 
GenerateDotNodeLabel(const Node * node,const std::unordered_map<const Node *,size_t> & roots_ids)129 std::string GenerateDotNodeLabel(
130     const Node* node,
131     const std::unordered_map<const Node*, size_t>& roots_ids) {
132   static const size_t kMaxValueSize = 64;
133   std::stringstream ss;
134   ss << node->op() << "\\n" << node->shape();
135   for (auto& tag : GetNodeTags(node)) {
136     ss << "\\n" << tag.name << "=";
137     if (tag.value.size() < kMaxValueSize) {
138       ss << tag.value;
139     } else {
140       ss << tag.value.substr(0, kMaxValueSize) << "...";
141     }
142   }
143   auto opt_root_id = GetRootNodeId(node, roots_ids);
144   if (opt_root_id) {
145     ss << "\\nROOT=" << *opt_root_id;
146   }
147   return ss.str();
148 }
149 
GenerateDotNodeSpec(const Node * node,const std::unordered_map<const Node *,size_t> & roots_ids)150 std::string GenerateDotNodeSpec(
151     const Node* node,
152     const std::unordered_map<const Node*, size_t>& roots_ids) {
153   std::stringstream ss;
154   ss << "label=\"" << GenerateDotNodeLabel(node, roots_ids) << "\"";
155   return ss.str();
156 }
157 
GenerateTextNodeSpec(const Node * node,const NodeIdMap & id_map)158 std::string GenerateTextNodeSpec(const Node* node, const NodeIdMap& id_map) {
159   std::stringstream ss;
160   ss << node->shapes() << " " << node->op() << "(";
161   size_t count = 0;
162   for (auto& output : node->operands()) {
163     if (count > 0) {
164       ss << ", ";
165     }
166     ss << "%" << id_map.at(output.node);
167     if (output.node->num_outputs() > 1) {
168       ss << "." << output.index;
169     }
170     ++count;
171   }
172   ss << ")";
173   for (auto& tag : GetNodeTags(node)) {
174     ss << ", " << tag.name << "=" << tag.value;
175   }
176   return ss.str();
177 }
178 
179 } // namespace
180 
ToDot(c10::ArrayRef<const Node * > nodes)181 std::string DumpUtil::ToDot(c10::ArrayRef<const Node*> nodes) {
182   auto post_order = Util::ComputePostOrder(nodes);
183   return PostOrderToDot(post_order, nodes);
184 }
185 
PostOrderToDot(c10::ArrayRef<const Node * > post_order,c10::ArrayRef<const Node * > roots)186 std::string DumpUtil::PostOrderToDot(
187     c10::ArrayRef<const Node*> post_order,
188     c10::ArrayRef<const Node*> roots) {
189   std::unordered_map<const Node*, size_t> roots_ids = GetRootsIds(roots);
190   NodeIdMap id_map = GenerateIdMap(post_order);
191   std::stringstream ss;
192   ss << "digraph G {\n";
193   for (auto node : post_order) {
194     ss << "  node" << id_map.at(node) << " ["
195        << GenerateDotNodeSpec(node, roots_ids) << "]\n";
196   }
197   for (auto it = post_order.rbegin(); it != post_order.rend(); ++it) {
198     const Node* node = *it;
199     size_t id = id_map.at(node);
200     for (const auto i : c10::irange(node->operands().size())) {
201       const Output& output = node->operand(i);
202       ss << "  node" << id_map.at(output.node) << " -> node" << id;
203       if (node->operands().size() > 1) {
204         ss << " [label=\"i=" << i;
205         if (output.node->num_outputs() > 1) {
206           ss << ",o=" << output.index;
207         }
208         ss << "\"]\n";
209       } else {
210         if (output.node->num_outputs() > 1) {
211           ss << " [label=\"o=" << output.index << "\"]";
212         }
213         ss << "\n";
214       }
215     }
216   }
217   ss << "}\n";
218   return ss.str();
219 }
220 
ToText(c10::ArrayRef<const Node * > nodes)221 std::string DumpUtil::ToText(c10::ArrayRef<const Node*> nodes) {
222   auto post_order = Util::ComputePostOrder(nodes);
223   return PostOrderToText(post_order, nodes);
224 }
225 
PostOrderToText(c10::ArrayRef<const Node * > post_order,c10::ArrayRef<const Node * > roots)226 std::string DumpUtil::PostOrderToText(
227     c10::ArrayRef<const Node*> post_order,
228     c10::ArrayRef<const Node*> roots) {
229   std::unordered_map<const Node*, size_t> roots_ids = GetRootsIds(roots);
230   NodeIdMap id_map = GenerateIdMap(post_order);
231   std::stringstream ss;
232   ss << "IR {\n";
233   for (auto node : post_order) {
234     auto opt_root_id = GetRootNodeId(node, roots_ids);
235     ss << "  %" << id_map.at(node) << " = "
236        << GenerateTextNodeSpec(node, id_map);
237     if (opt_root_id) {
238       ss << ", ROOT=" << *opt_root_id;
239     }
240     ss << ", NodeType=" << typeid(*node).name();
241     ss << "\n";
242   }
243   ss << "}\n";
244   return ss.str();
245 }
246 
ToBackend(c10::ArrayRef<Value> values,const BackendDevice & device)247 std::string DumpUtil::ToBackend(
248     c10::ArrayRef<Value> values,
249     const BackendDevice& device) {
250   auto lowering_ctx = LoweringContext::Create("IrToBackend", device);
251   for (auto& ir_value : values) {
252     lowering_ctx->AddResult(ir_value);
253   }
254   auto computation = lowering_ctx->Build();
255   return getBackend()->GetComputationBackendText(computation);
256 }
257 
258 } // namespace lazy
259 } // namespace torch
260