# xref: /aosp_15_r20/external/pytorch/test/distributed/checkpoint/test_traverse.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: distributed"]
2
3from collections import OrderedDict
4from typing import TYPE_CHECKING
5
6import torch
7import torch.distributed.checkpoint._traverse as _traverse
8from torch.testing._internal.common_utils import run_tests, TestCase
9
10
11if TYPE_CHECKING:
12    from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
13
14
15# TODO: add comments for TestTraverse
16class TestTraverse(TestCase):
17    def test_traverse_shallow(self) -> None:
18        state_dict = {
19            "key0": 1,
20            "key1": [1, 2],
21            "key2": {1: 2, 2: 3},
22            "key3": torch.tensor([1]),
23        }
24
25        data = {}
26
27        def collect_data(path, value):
28            nonlocal data
29            data[path] = value
30
31        _traverse.traverse_state_dict(state_dict, collect_data)
32
33        self.assertIn(("key0",), data)
34        self.assertEqual(data[("key0",)], 1)
35
36        self.assertIn(("key1",), data)
37        self.assertEqual(data[("key1",)], [1, 2])
38
39        self.assertIn(("key2", "1"), data)
40        self.assertEqual(data[("key2", "1")], 2)
41        self.assertIn(("key2", "2"), data)
42        self.assertEqual(data[("key2", "2")], 3)
43
44        self.assertIn(("key3",), data)
45        self.assertEqual(data[("key3",)], torch.tensor([1]))
46
47    def test_traverse_nested_list(self) -> None:
48        state_dict = {
49            "key1": [
50                torch.tensor([1]),
51                [33, torch.tensor([2]), [44, 55]],
52                [66, 77],
53            ],
54        }
55
56        data = {}
57
58        def collect_data(path, value):
59            nonlocal data
60            data[path] = value
61
62        _traverse.traverse_state_dict(state_dict, collect_data)
63
64        self.assertNotIn(("key1"), data)
65
66        self.assertIn(("key1", 0), data)
67        self.assertEqual(data[("key1", 0)], torch.tensor([1]))
68
69        self.assertIn(("key1", 1, 0), data)
70        self.assertEqual(data[("key1", 1, 0)], 33)
71
72        self.assertIn(("key1", 1, 1), data)
73        self.assertEqual(data[("key1", 1, 1)], torch.tensor([2]))
74
75        self.assertIn(("key1", 1, 2), data)
76        self.assertEqual(data[("key1", 1, 2)], [44, 55])
77        self.assertNotIn(("key1", 1, 2, 0), data)
78
79        self.assertIn(("key1", 2), data)
80        self.assertEqual(data[("key1", 2)], [66, 77])
81
82    def test_traverse_nested_dict(self) -> None:
83        state_dict = {
84            "key0": {"key1": 99, "key2": torch.tensor([1])},
85        }
86
87        data = {}
88
89        def collect_data(path, value):
90            nonlocal data
91            data[path] = value
92
93        _traverse.traverse_state_dict(state_dict, collect_data)
94
95        self.assertNotIn(("key0",), data)
96
97        self.assertIn(("key0", "key1"), data)
98        self.assertEqual(data[("key0", "key1")], 99)
99
100        self.assertIn(("key0", "key2"), data)
101        self.assertEqual(data[("key0", "key2")], torch.tensor([1]))
102
103    def test_traverse_doesnt_ignore_intermediate_collections(self) -> None:
104        state_dict: STATE_DICT_TYPE = {"key0": [{"key1": {"key2": torch.tensor([1])}}]}
105
106        data = {}
107
108        def collect_data(path, value):
109            nonlocal data
110            data[path] = value
111
112        _traverse.traverse_state_dict(state_dict, collect_data)
113
114        self.assertIn(("key0", 0, "key1", "key2"), data)
115        self.assertEqual(
116            data[("key0", 0, "key1", "key2")],
117            torch.tensor([1]),
118        )
119
120    def test_traverse_with_ordered_dict(self) -> None:
121        state_dict = OrderedDict(
122            {
123                "key0": [
124                    99,
125                    torch.tensor([3]),
126                ]
127            }
128        )
129
130        data = {}
131
132        def collect_data(path, value):
133            nonlocal data
134            data[path] = value
135
136        _traverse.traverse_state_dict(state_dict, collect_data)
137
138        self.assertIn(("key0", 0), data)
139        self.assertEqual(data[("key0", 0)], 99)
140
141        self.assertIn(("key0", 1), data)
142        self.assertEqual(data[("key0", 1)], torch.tensor([3]))
143
144    def test_set_element(self) -> None:
145        state_dict: STATE_DICT_TYPE = {}
146
147        _traverse.set_element(state_dict, ("k",), 10)
148        self.assertEqual(state_dict["k"], 10)
149
150        _traverse.set_element(state_dict, ("k1", 2), 1)
151        self.assertEqual(state_dict["k1"], [None, None, 1])
152
153        _traverse.set_element(state_dict, ("k1", 1), 99)
154        self.assertEqual(state_dict["k1"], [None, 99, 1])
155
156        _traverse.set_element(state_dict, ("k1", 3), 88)
157        self.assertEqual(state_dict["k1"], [None, 99, 1, 88])
158
159        _traverse.set_element(state_dict, ("k2", "k3"), 3)
160        self.assertEqual(state_dict["k2"], {"k3": 3})
161
162        _traverse.set_element(state_dict, ("k2", "k4", 0, 0), 99)
163        self.assertEqual(state_dict["k2"]["k4"][0], [99])
164
165    def test_get_element(self) -> None:
166        state_dict = {"a": [0, 1], "b": [2, {"c": "d"}]}
167        self.assertEqual(_traverse.get_element(state_dict, ("a",)), [0, 1])
168        self.assertEqual(_traverse.get_element(state_dict, ("b", 0)), 2)
169        self.assertEqual(_traverse.get_element(state_dict, ("b", 1, "c")), "d")
170
171        self.assertIsNone(_traverse.get_element(state_dict, ("c",)))
172        self.assertIsNone(_traverse.get_element(state_dict, ("a", 33)))
173        self.assertIsNone(_traverse.get_element(state_dict, ("b", 88)))
174        self.assertIsNone(_traverse.get_element(state_dict, ("b", 0, 2)))
175        self.assertIsNone(_traverse.get_element(state_dict, ("b", 1, 2)))
176        self.assertIsNone(_traverse.get_element(state_dict, ("b", 1, "d")))
177
178
# When executed directly (not imported by a test collector), hand control to
# the shared PyTorch test entry point imported above.
if __name__ == "__main__":
    run_tests()