xref: /aosp_15_r20/external/pytorch/test/package/test_digraph.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: package/deploy"]
2
3from torch.package._digraph import DiGraph
4from torch.testing._internal.common_utils import run_tests
5
6
7try:
8    from .common import PackageTestCase
9except ImportError:
10    # Support the case where we run this file directly.
11    from common import PackageTestCase
12
13
class TestDiGraph(PackageTestCase):
    """Test the DiGraph structure we use to represent dependencies in PackageExporter."""

    def test_successors(self):
        """Edge targets are reported as successors; an isolated node has none."""
        g = DiGraph()
        g.add_edge("foo", "bar")
        g.add_edge("foo", "baz")
        g.add_node("qux")

        foo_successors = list(g.successors("foo"))
        self.assertIn("bar", foo_successors)
        self.assertIn("baz", foo_successors)
        self.assertEqual(len(list(g.successors("qux"))), 0)

    def test_predecessors(self):
        """Edge sources are reported as predecessors; an isolated node has none."""
        g = DiGraph()
        g.add_edge("foo", "bar")
        g.add_edge("foo", "baz")
        g.add_node("qux")

        self.assertIn("foo", list(g.predecessors("bar")))
        self.assertIn("foo", list(g.predecessors("baz")))
        self.assertEqual(len(list(g.predecessors("qux"))), 0)

    def test_successor_not_in_graph(self):
        """Asking for the successors of an unknown node raises ValueError."""
        g = DiGraph()
        # The call itself must raise; the result is deliberately not iterated.
        with self.assertRaises(ValueError):
            g.successors("not in graph")

    def test_predecessor_not_in_graph(self):
        """Asking for the predecessors of an unknown node raises ValueError."""
        g = DiGraph()
        # The call itself must raise; the result is deliberately not iterated.
        with self.assertRaises(ValueError):
            g.predecessors("not in graph")

    def test_node_attrs(self):
        """Keyword arguments given to add_node are readable through g.nodes."""
        g = DiGraph()
        g.add_node("foo", my_attr=1, other_attr=2)

        foo_attrs = g.nodes["foo"]
        self.assertEqual(foo_attrs["my_attr"], 1)
        self.assertEqual(foo_attrs["other_attr"], 2)

    def test_node_attr_update(self):
        """Re-adding an existing node with the same attribute name replaces its value."""
        g = DiGraph()
        g.add_node("foo", my_attr=1)
        self.assertEqual(g.nodes["foo"]["my_attr"], 1)

        g.add_node("foo", my_attr="different")
        self.assertEqual(g.nodes["foo"]["my_attr"], "different")

    def test_edges(self):
        """g.edges yields exactly the (source, target) pairs that were added."""
        g = DiGraph()
        added_edges = [(1, 2), (2, 3), (1, 3), (4, 5)]
        for source, target in added_edges:
            g.add_edge(source, target)

        # Order-insensitive comparison that also catches missing, extra or
        # duplicated edges, and prints both sides when it fails.
        self.assertCountEqual(list(g.edges), added_edges)

    def test_iter(self):
        """Iterating over the graph yields its nodes."""
        g = DiGraph()
        g.add_node(1)
        g.add_node(2)
        g.add_node(3)

        self.assertEqual(set(g), {1, 2, 3})

    def test_contains(self):
        """The `in` operator reports whether a node has been added."""
        g = DiGraph()
        g.add_node("yup")

        self.assertIn("yup", g)
        self.assertNotIn("nup", g)

    def test_contains_non_hashable(self):
        """Membership tests with an unhashable value are False instead of raising."""
        g = DiGraph()
        self.assertNotIn([1, 2, 3], g)

    def test_forward_closure(self):
        """forward_transitive_closure returns the start node plus everything reachable from it."""
        g = DiGraph()
        g.add_edge("1", "2")
        g.add_edge("2", "3")
        g.add_edge("5", "4")
        g.add_edge("4", "3")

        self.assertEqual(g.forward_transitive_closure("1"), {"1", "2", "3"})
        self.assertEqual(g.forward_transitive_closure("4"), {"4", "3"})

    def test_all_paths(self):
        """all_paths("1", "3") lists only edges lying on a path from "1" to "3"."""
        g = DiGraph()
        g.add_edge("1", "2")
        g.add_edge("1", "7")
        g.add_edge("7", "8")
        g.add_edge("8", "3")
        g.add_edge("2", "3")
        g.add_edge("5", "4")
        g.add_edge("4", "3")

        result = g.all_paths("1", "3")
        # The result is a ";"-separated string. The first two pieces and the
        # last one are dropped (presumably the DOT header and closing brace --
        # confirm against DiGraph.to_dot), leaving one edge statement per piece.
        # Compare as a set because the order of the edges is not deterministic.
        actual = {piece.strip("\n") for piece in result.split(";")[2:-1]}
        expected = {
            '"2" -> "3"',
            '"1" -> "7"',
            '"7" -> "8"',
            '"1" -> "2"',
            '"8" -> "3"',
        }
        self.assertEqual(actual, expected)
128
129
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
132