# Owner(s): ["oncall: distributed"]

import torch
from torch.distributed.checkpoint._nested_dict import (
    flatten_state_dict,
    unflatten_state_dict,
)
from torch.testing._internal.common_utils import run_tests, TestCase


class TestFlattening(TestCase):
    def test_flattening_round_trip(self) -> None:
        state_dict = {
            "key0": 1,
            "key1": [1, 2],
            "key2": {"1": 2, "2": 3},
            "key3": torch.tensor([1]),
            "key4": [[torch.tensor(2), "x"], [1, 2, 3], {"key6": [44]}],
        }

        flatten_dict, mapping = flatten_state_dict(state_dict)
        """
        flatten_dict:
        {
            'key0': 1,
            'key1': [1, 2],
            'key2.1': 2,
            'key2.2': 3,
            'key3': tensor([1]),
            'key4.0.0': tensor(2),
            'key4.0.1': 'x',
            'key4.1': [1, 2, 3],
            'key4.2.key6': [44]
        }
        """
        restored = unflatten_state_dict(flatten_dict, mapping)

        self.assertEqual(state_dict, restored)

    def test_mapping(self) -> None:
        state_dict = {
            "k0": [1],
            "k2": [torch.tensor([1]), 99, [{"k3": torch.tensor(1)}]],
            "k3": ["x", 99, [{"k3": "y"}]],
        }

        flatten_dict, mapping = flatten_state_dict(state_dict)
        """
        flatten_dict:
        {'k0': [1], 'k2.0': tensor([1]), 'k2.1': 99, 'k2.2.0.k3': tensor(1),
         'k3.0': 'x', 'k3.1': 99, 'k3.2.0.k3': 'y'}
        mapping:
        {'k0': ('k0',), 'k2.0': ('k2', 0), 'k2.1': ('k2', 1), 'k2.2.0.k3': ('k2', 2, 0, 'k3'),
         'k3.0': ('k3', 0), 'k3.1': ('k3', 1), 'k3.2.0.k3': ('k3', 2, 0, 'k3')}
        """

        self.assertEqual(("k0",), mapping["k0"])
        self.assertEqual(("k2", 0), mapping["k2.0"])
        self.assertEqual(("k2", 1), mapping["k2.1"])
        self.assertEqual(("k2", 2, 0, "k3"), mapping["k2.2.0.k3"])
        self.assertEqual(("k3", 0), mapping["k3.0"])
        self.assertEqual(("k3", 1), mapping["k3.1"])
        self.assertEqual(("k3", 2, 0, "k3"), mapping["k3.2.0.k3"])


if __name__ == "__main__":
    run_tests()
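

# A minimal usage sketch, kept separate from the tests above and meant to be
# run by hand only: it uses nothing beyond the two helpers imported at the top
# of this file, and the function name `_demo_flatten_round_trip` is
# illustrative, not part of the torch.distributed.checkpoint API.
def _demo_flatten_round_trip() -> None:
    nested = {"model": {"weight": torch.zeros(2, 2), "step": 3}}

    # Flattened keys are the object paths joined with ".", and the returned
    # mapping records each flattened key's original path as a tuple.
    flat, mapping = flatten_state_dict(nested)
    assert mapping["model.weight"] == ("model", "weight")
    assert mapping["model.step"] == ("model", "step")

    # The same mapping drives the inverse operation.
    restored = unflatten_state_dict(flat, mapping)
    assert restored["model"]["step"] == 3
    assert torch.equal(restored["model"]["weight"], nested["model"]["weight"])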