xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/ir_util.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/lazy/core/ir_util.h>
2 
3 #include <c10/util/Logging.h>
4 
5 namespace torch {
6 namespace lazy {
7 
ComputePostOrder(const Node * node,EmissionMap * emap)8 std::vector<const Node*> Util::ComputePostOrder(
9     const Node* node,
10     EmissionMap* emap) {
11   std::vector<const Node*> post_order;
12   std::vector<const Node*> queue;
13   queue.push_back(node);
14   while (!queue.empty()) {
15     node = queue.back();
16     auto it = emap->find(node);
17     if (it == emap->end()) {
18       (*emap)[node] = kEmitting;
19       for (auto& output : node->operands()) {
20         auto oit = emap->find(output.node);
21         if (oit == emap->end()) {
22           queue.push_back(output.node);
23         } else {
24           TORCH_CHECK(
25               oit->second != kEmitting,
26               "Graph loop found at ",
27               output.node->ToString());
28         }
29       }
30     } else if (it->second == kEmitting) {
31       for (auto& output : node->operands()) {
32         auto oit = emap->find(output.node);
33         TORCH_CHECK(
34             oit != emap->end() && oit->second == kEmitted,
35             "Graph loop found at ",
36             output.node->ToString());
37       }
38       (*emap)[node] = kEmitted;
39       post_order.push_back(node);
40       queue.pop_back();
41     } else {
42       TORCH_CHECK(it->second == kEmitted);
43       queue.pop_back();
44     }
45   }
46   return post_order;
47 }
48 
ComputePostOrder(c10::ArrayRef<const Node * > nodes,EmissionMap * emap)49 std::vector<const Node*> Util::ComputePostOrder(
50     c10::ArrayRef<const Node*> nodes,
51     EmissionMap* emap) {
52   std::vector<const Node*> post_order;
53   for (auto node : nodes) {
54     auto node_post_order = ComputePostOrder(node, emap);
55     post_order.insert(
56         post_order.end(), node_post_order.begin(), node_post_order.end());
57   }
58   return post_order;
59 }
60 
ComputePostOrder(c10::ArrayRef<const Node * > nodes)61 std::vector<const Node*> Util::ComputePostOrder(
62     c10::ArrayRef<const Node*> nodes) {
63   EmissionMap emap;
64   return ComputePostOrder(nodes, &emap);
65 }
66 
GetGraphSize(c10::ArrayRef<const Node * > nodes)67 size_t Util::GetGraphSize(c10::ArrayRef<const Node*> nodes) {
68   return ComputePostOrder(nodes).size();
69 }
70 
71 } // namespace lazy
72 } // namespace torch
73