xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_graph_iterator.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <iostream>
2 #include <sstream>
3 #include <string>
4 
5 #include <gtest/gtest.h>
6 
7 #include <test/cpp/jit/test_utils.h>
8 #include <torch/csrc/jit/ir/irparser.h>
9 #include <torch/csrc/jit/runtime/graph_iterator.h>
10 #include <torch/jit.h>
11 #include <torch/script.h>
12 #include <torch/torch.h>
13 
14 namespace torch {
15 namespace jit {
16 
/**
 * Inverts an unordered map, producing a value -> key mapping.
 *
 * If several keys map to the same value, only one of them is kept: the first
 * one encountered while iterating `map`. Because unordered_map iteration order
 * is unspecified, which key survives a collision is unspecified as well.
 *
 * @param map the map to invert; it is not modified.
 * @return a new map from each distinct value of `map` to one of its keys.
 */
template <typename K, typename V>
std::unordered_map<V, K> invert_map(const std::unordered_map<K, V>& map) {
  std::unordered_map<V, K> inverted;
  inverted.reserve(map.size());
  // `const auto&` binds directly to the element type std::pair<const K, V>.
  // Spelling it `const std::pair<K, V>&` would force a converted temporary
  // copy of every element.
  for (const auto& entry : map) {
    // emplace() never overwrites an existing entry, so on a value collision
    // the first-seen key is retained.
    inverted.emplace(entry.second, entry.first);
  }
  return inverted;
}
28 
29 /**
30  * Traverses the graph using the DepthFirstGraphNodeIterator and
31  * returns an array containing the original names in the string
32  * graph.
33  */
traverse_depth_first(std::string graph_string,int max_count=100)34 std::vector<std::string> traverse_depth_first(
35     std::string graph_string,
36     int max_count = 100) {
37   auto graph = std::make_shared<Graph>();
38   std::unordered_map<std::string, Value*> vmap;
39   torch::jit::parseIR(graph_string, graph.get(), vmap);
40   auto get_name = invert_map(vmap);
41 
42   std::vector<std::string> result;
43   DepthFirstGraphNodeIterator graph_it(graph);
44   Node* node = graph_it.next();
45   int count = 0;
46   while (node && count < max_count) {
47     std::stringstream buffer;
48     std::vector<const torch::jit::Node*> vec;
49     node->print(buffer, 0, &vec, false, true, true, false);
50     result.push_back(buffer.str());
51     node = graph_it.next();
52     ++count;
53   }
54   return result;
55 }
56 
57 /** Checks that the iteration order matches the expected/provided order. */
assert_ordering(std::vector<std::string> actual,std::initializer_list<std::string> expected_list)58 void assert_ordering(
59     std::vector<std::string> actual,
60     std::initializer_list<std::string> expected_list) {
61   auto expected = std::vector<std::string>(expected_list);
62   ASSERT_EQ(expected.size(), actual.size())
63       << "Got " << actual.size() << " elements (" << actual << ")"
64       << " expected " << expected.size() << " elements (" << expected << ")";
65   for (unsigned i = 0; i < expected.size(); i++) {
66     ASSERT_EQ(expected[i], actual[i])
67         << "Difference at index " << i << " in " << actual << " (expected "
68         << actual << ")";
69   }
70 }
71 
// A graph whose only node is a single constant: the iterator must yield that
// node and then report exhaustion with nullptr.
TEST(GraphIteratorTest, ConstantReturnGraph) {
  const auto graph_string = R"IR(
      graph():
        %1 : int = prim::Constant[value=0]()
        return (%1))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());
  DepthFirstGraphNodeIterator graph_it(graph);
  // Check for nullptr before dereferencing so a broken iterator fails this
  // test cleanly instead of crashing the whole test binary.
  Node* first = graph_it.next();
  ASSERT_NE(first, nullptr);
  ASSERT_EQ(first->kind(), prim::Constant);
  ASSERT_EQ(graph_it.next(), nullptr);
}
83 
// The graph input %0 is not expected in the traversal output; only the
// constant node is listed.
TEST(GraphIteratorTest, GraphWithParameters) {
  const auto ir = R"IR(
      graph(%0 : Double(2)):
        %1 : int = prim::Constant[value=0]()
        return (%0))IR";
  assert_ordering(
      traverse_depth_first(ir), {"%1 : int = prim::Constant[value=0]()"});
}
92 
// A single prim::If whose blocks contain only returns. The If node is listed
// once, followed by the node after it at top level.
//
// Note: %a is redefined as an int constant, shadowing the Tensor input, so the
// expected output shows aten::Bool consuming the constant (printed as %1).
TEST(GraphIteratorTest, GraphWithIf) {
  const auto ir = R"IR(
graph(%a : Tensor):
  %a : int = prim::Constant[value=30]()
  %b : int = prim::Constant[value=10]()
  %c : bool = aten::Bool(%a)
  %d : int = prim::If(%c)
    block0():
      -> (%a)
    block1():
      -> (%b)
  %e : int = prim::Constant[value=20]()
  return (%d)
)IR";
  assert_ordering(
      traverse_depth_first(ir),
      {"%1 : int = prim::Constant[value=30]()",
       "%2 : int = prim::Constant[value=10]()",
       "%3 : bool = aten::Bool(%1)",
       "%4 : int = prim::If(%3)",
       "%5 : int = prim::Constant[value=20]()"});
}
116 
// Nested prim::If nodes. The expected order is pre-order: each If node comes
// before the nodes inside its blocks (block0, then block1), and the walk then
// resumes at top level.
//
// The top-level %8..%11 reuse names already defined inside the nested blocks.
// The expected output lists them as %12..%15, presumably because the
// parsed graph gives the redefinitions fresh unique names.
TEST(GraphIteratorTest, GraphWithNestedIf) {
  const auto ir = R"IR(
graph(%a.1 : Tensor,
      %b.1 : Tensor):
  %2 : int = prim::Constant[value=10]()
  %3 : int = prim::Constant[value=20]()
  %4 : int = prim::Constant[value=30]()
  %5 : int = prim::Constant[value=40]()
  %6 : bool = aten::Bool(%a.1)
  %7 : int = prim::If(%6)
    block0():
      %8 : bool = aten::Bool(%b.1)
      %9 : int = prim::If(%8)
        block0():
          -> (%2)
        block1():
          -> (%3)
      -> (%9)
    block1():
      %10 : bool = aten::Bool(%b.1)
      %11 : int = prim::If(%10)
        block0():
          -> (%4)
        block1():
          -> (%5)
      -> (%11)
  %8 : bool = aten::Bool(%b.1)
  %9 : int = prim::If(%8)
    block0():
      -> (%2)
    block1():
      -> (%3)
  %10 : bool = aten::Bool(%b.1)
  %11 : int = prim::If(%10)
    block0():
      -> (%4)
    block1():
      -> (%5)
  return (%7)
)IR";
  const std::vector<std::string> visited = traverse_depth_first(ir);
  assert_ordering(
      visited,
      {// top-level constants and the outer condition
       "%2 : int = prim::Constant[value=10]()",
       "%3 : int = prim::Constant[value=20]()",
       "%4 : int = prim::Constant[value=30]()",
       "%5 : int = prim::Constant[value=40]()",
       "%6 : bool = aten::Bool(%a.1)",
       "%7 : int = prim::If(%6)",
       // contents of the outer If's block0
       "%8 : bool = aten::Bool(%b.1)",
       "%9 : int = prim::If(%8)",
       // contents of the outer If's block1
       "%10 : bool = aten::Bool(%b.1)",
       "%11 : int = prim::If(%10)",
       // remaining top-level nodes, renamed
       "%12 : bool = aten::Bool(%b.1)",
       "%13 : int = prim::If(%12)",
       "%14 : bool = aten::Bool(%b.1)",
       "%15 : int = prim::If(%14)"});
}
175 
// Two prim::Loop nodes. The aten::add_ inside the first loop body is visited
// right after that loop, before the second loop.
//
// Both loop bodies declare %b.9, and the printed names differ from the
// source names: the add_ result %5 is expected as %7 with input %b.10, and
// the second loop %6 as %8. This is presumably unique-naming by the parser.
TEST(GraphIteratorTest, GraphWithLoop) {
  const auto ir = R"IR(
graph(%a.1 : Tensor):
  %1 : bool = prim::Constant[value=1]()
  %2 : int = prim::Constant[value=10]()
  %3 : int = prim::Constant[value=1]()
  %4 : Tensor = prim::Loop(%2, %1, %a.1)
    block0(%i : int, %b.9 : Tensor):
      %5 : Tensor = aten::add_(%b.9, %3, %3)
      -> (%1, %5)
  %6 : Tensor = prim::Loop(%2, %1, %a.1)
    block0(%i : int, %b.9 : Tensor):
      -> (%1, %4)
  return (%6)
)IR";
  const std::vector<std::string> visited = traverse_depth_first(ir);
  assert_ordering(
      visited,
      {"%1 : bool = prim::Constant[value=1]()",
       "%2 : int = prim::Constant[value=10]()",
       "%3 : int = prim::Constant[value=1]()",
       "%4 : Tensor = prim::Loop(%2, %1, %a.1)",
       "%7 : Tensor = aten::add_(%b.10, %3, %3)",
       "%8 : Tensor = prim::Loop(%2, %1, %a.1)"});
}
201 
202 } // namespace jit
203 } // namespace torch
204