1# Owner(s): ["oncall: distributed"] 2 3from collections import OrderedDict 4from typing import TYPE_CHECKING 5 6import torch 7import torch.distributed.checkpoint._traverse as _traverse 8from torch.testing._internal.common_utils import run_tests, TestCase 9 10 11if TYPE_CHECKING: 12 from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE 13 14 15# TODO: add comments for TestTraverse 16class TestTraverse(TestCase): 17 def test_traverse_shallow(self) -> None: 18 state_dict = { 19 "key0": 1, 20 "key1": [1, 2], 21 "key2": {1: 2, 2: 3}, 22 "key3": torch.tensor([1]), 23 } 24 25 data = {} 26 27 def collect_data(path, value): 28 nonlocal data 29 data[path] = value 30 31 _traverse.traverse_state_dict(state_dict, collect_data) 32 33 self.assertIn(("key0",), data) 34 self.assertEqual(data[("key0",)], 1) 35 36 self.assertIn(("key1",), data) 37 self.assertEqual(data[("key1",)], [1, 2]) 38 39 self.assertIn(("key2", "1"), data) 40 self.assertEqual(data[("key2", "1")], 2) 41 self.assertIn(("key2", "2"), data) 42 self.assertEqual(data[("key2", "2")], 3) 43 44 self.assertIn(("key3",), data) 45 self.assertEqual(data[("key3",)], torch.tensor([1])) 46 47 def test_traverse_nested_list(self) -> None: 48 state_dict = { 49 "key1": [ 50 torch.tensor([1]), 51 [33, torch.tensor([2]), [44, 55]], 52 [66, 77], 53 ], 54 } 55 56 data = {} 57 58 def collect_data(path, value): 59 nonlocal data 60 data[path] = value 61 62 _traverse.traverse_state_dict(state_dict, collect_data) 63 64 self.assertNotIn(("key1"), data) 65 66 self.assertIn(("key1", 0), data) 67 self.assertEqual(data[("key1", 0)], torch.tensor([1])) 68 69 self.assertIn(("key1", 1, 0), data) 70 self.assertEqual(data[("key1", 1, 0)], 33) 71 72 self.assertIn(("key1", 1, 1), data) 73 self.assertEqual(data[("key1", 1, 1)], torch.tensor([2])) 74 75 self.assertIn(("key1", 1, 2), data) 76 self.assertEqual(data[("key1", 1, 2)], [44, 55]) 77 self.assertNotIn(("key1", 1, 2, 0), data) 78 79 self.assertIn(("key1", 2), data) 80 self.assertEqual(data[("key1", 2)], [66, 77]) 81 82 def test_traverse_nested_dict(self) -> None: 83 state_dict = { 84 "key0": {"key1": 99, "key2": torch.tensor([1])}, 85 } 86 87 data = {} 88 89 def collect_data(path, value): 90 nonlocal data 91 data[path] = value 92 93 _traverse.traverse_state_dict(state_dict, collect_data) 94 95 self.assertNotIn(("key0",), data) 96 97 self.assertIn(("key0", "key1"), data) 98 self.assertEqual(data[("key0", "key1")], 99) 99 100 self.assertIn(("key0", "key2"), data) 101 self.assertEqual(data[("key0", "key2")], torch.tensor([1])) 102 103 def test_traverse_doesnt_ignore_intermediate_collections(self) -> None: 104 state_dict: STATE_DICT_TYPE = {"key0": [{"key1": {"key2": torch.tensor([1])}}]} 105 106 data = {} 107 108 def collect_data(path, value): 109 nonlocal data 110 data[path] = value 111 112 _traverse.traverse_state_dict(state_dict, collect_data) 113 114 self.assertIn(("key0", 0, "key1", "key2"), data) 115 self.assertEqual( 116 data[("key0", 0, "key1", "key2")], 117 torch.tensor([1]), 118 ) 119 120 def test_traverse_with_ordered_dict(self) -> None: 121 state_dict = OrderedDict( 122 { 123 "key0": [ 124 99, 125 torch.tensor([3]), 126 ] 127 } 128 ) 129 130 data = {} 131 132 def collect_data(path, value): 133 nonlocal data 134 data[path] = value 135 136 _traverse.traverse_state_dict(state_dict, collect_data) 137 138 self.assertIn(("key0", 0), data) 139 self.assertEqual(data[("key0", 0)], 99) 140 141 self.assertIn(("key0", 1), data) 142 self.assertEqual(data[("key0", 1)], torch.tensor([3])) 143 144 def test_set_element(self) -> None: 145 state_dict: STATE_DICT_TYPE = {} 146 147 _traverse.set_element(state_dict, ("k",), 10) 148 self.assertEqual(state_dict["k"], 10) 149 150 _traverse.set_element(state_dict, ("k1", 2), 1) 151 self.assertEqual(state_dict["k1"], [None, None, 1]) 152 153 _traverse.set_element(state_dict, ("k1", 1), 99) 154 self.assertEqual(state_dict["k1"], [None, 99, 1]) 155 156 _traverse.set_element(state_dict, ("k1", 3), 88) 157 self.assertEqual(state_dict["k1"], [None, 99, 1, 88]) 158 159 _traverse.set_element(state_dict, ("k2", "k3"), 3) 160 self.assertEqual(state_dict["k2"], {"k3": 3}) 161 162 _traverse.set_element(state_dict, ("k2", "k4", 0, 0), 99) 163 self.assertEqual(state_dict["k2"]["k4"][0], [99]) 164 165 def test_get_element(self) -> None: 166 state_dict = {"a": [0, 1], "b": [2, {"c": "d"}]} 167 self.assertEqual(_traverse.get_element(state_dict, ("a",)), [0, 1]) 168 self.assertEqual(_traverse.get_element(state_dict, ("b", 0)), 2) 169 self.assertEqual(_traverse.get_element(state_dict, ("b", 1, "c")), "d") 170 171 self.assertIsNone(_traverse.get_element(state_dict, ("c",))) 172 self.assertIsNone(_traverse.get_element(state_dict, ("a", 33))) 173 self.assertIsNone(_traverse.get_element(state_dict, ("b", 88))) 174 self.assertIsNone(_traverse.get_element(state_dict, ("b", 0, 2))) 175 self.assertIsNone(_traverse.get_element(state_dict, ("b", 1, 2))) 176 self.assertIsNone(_traverse.get_element(state_dict, ("b", 1, "d"))) 177 178 179if __name__ == "__main__": 180 run_tests() 181