# Owner(s): ["oncall: distributed"]

from copy import deepcopy
from typing import List, Tuple

import torch
import torch.nn as nn
from torch.distributed._composable import _get_registry, contract
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.seq1 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(2)])
        self.seq2 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(2)])
        self.p = nn.Parameter(torch.randn(10, 10), requires_grad=True)
        self.b = torch.zeros(1)  # tensor state, not a registered buffer

    def forward(self, x, y):
        with torch.no_grad():
            self.b += x.sum() + y.sum()

        return self.p + self.seq1(x) + self.seq2(y)


class TestContract(TestCase):
    @skipIfTorchDynamo("Dynamo does not support the state key")
    def test_add_hooks(self):
        def forward_pre_hook(
            module: nn.Module, inp: Tuple[torch.Tensor]
        ) -> Tuple[torch.Tensor]:
            return inp

        def forward_hook(
            module: nn.Module, inp: Tuple[torch.Tensor], out: torch.Tensor
        ) -> torch.Tensor:
            return out

        def backward_pre_hook(
            module: nn.Module, grad_output: Tuple[torch.Tensor, ...]
        ) -> Tuple[torch.Tensor, ...]:
            return grad_output

        def backward_hook(
            module: nn.Module,
            grad_input: Tuple[torch.Tensor, ...],
            grad_output: Tuple[torch.Tensor, ...],
        ) -> Tuple[torch.Tensor, ...]:
            return grad_input

        @contract()
        def noop_api(module: nn.Module) -> nn.Module:
            module.register_forward_pre_hook(forward_pre_hook)
            module.register_forward_hook(forward_hook)
            module.register_full_backward_pre_hook(backward_pre_hook)
            module.register_full_backward_hook(backward_hook)
            return module

        model = ToyModel()
        model_with_hooks = deepcopy(model)
        # Apply the no-op API to the copy; its numerics should match the
        # original model exactly.
        noop_api(model_with_hooks.seq1)
        noop_api(model_with_hooks.seq2)

        x, y = torch.randn(10, 10), torch.randn(10, 10)
        model(x, y).sum().backward()
        model_with_hooks(x, y).sum().backward()

        for p1, p2 in zip(model.parameters(), model_with_hooks.parameters()):
            self.assertEqual(p1, p2)

    @skipIfTorchDynamo("Dynamo does not support the state key")
    def test_modify_fqn(self):
        class ModelWrapper(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, x):
                return self.module(x)

        @contract()
        def wrap_module(module: nn.Module) -> nn.Module:
            return ModelWrapper(module)

        model = ToyModel()

        # Returning a wrapper module would prefix every parameter FQN with
        # "module.", which contract() must reject.
        regex = "Checking parameters: Composable distributed API implementations cannot modify FQNs."
        with self.assertRaisesRegex(RuntimeError, regex):
            wrap_module(model.seq1)

    @skipIfTorchDynamo("Dynamo does not support the state key")
    def test_state(self):
        def check_and_update_state_hook(
            module: nn.Module, inp: Tuple[torch.Tensor]
        ) -> Tuple[torch.Tensor]:
            self.assertEqual(api.state(module).dummy_state, 7)
            api.state(module).dummy_state = 8
            return inp

        # FIXME: The circular reference looks a bit weird. Should we make
        # .state a top-level API instead of attaching it to the contract API?
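        # Note: contract() wraps the decorated function and exposes a
        # .state(module) accessor on the wrapper, which is why api.state(...)
        # is callable both inside `api` itself and from the hook above.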
        @contract()
        def api(module: nn.Module) -> nn.Module:
            api.state(module).dummy_state = 7
            module.register_forward_pre_hook(check_and_update_state_hook)
            return module

        model = ToyModel()
        api(model.seq1)

        self.assertEqual(api.state(model.seq1).dummy_state, 7)
        model(torch.zeros(10, 10), torch.zeros(10, 10))
        self.assertEqual(api.state(model.seq1).dummy_state, 8)

    @skipIfTorchDynamo("Dynamo does not support the state key")
    def test_registry(self):
        @contract()
        def api1(module: nn.Module) -> nn.Module:
            return module

        @contract()
        def api2(module: nn.Module) -> nn.Module:
            return module

        model = ToyModel()
        model = api1(model)
        self.assertEqual(1, len(_get_registry(model)))
        self.assertIn("api1", _get_registry(model))
        model = api2(model)
        self.assertEqual(2, len(_get_registry(model)))
        self.assertEqual(list(_get_registry(model).keys()), ["api1", "api2"])
        # The registry lives on the module that the API was applied to, not
        # on its submodules.
        self.assertIsNone(_get_registry(model.seq1))
        self.assertIsNone(_get_registry(model.seq2))

        with self.assertRaisesRegex(AssertionError, "api1 has already been applied"):
            model = api1(model)

    @skipIfTorchDynamo("Dynamo does not support the state key")
    def test_multi_module_api(self):
        @contract()
        def multi_module_api(modules: List[nn.Module]) -> List[nn.Module]:
            return modules

        model = nn.Sequential(*[nn.Linear(3, 3) for _ in range(5)])
        multi_module_api([model[0], model[1]])
        multi_module_api([model[2], model[3]])
        multi_module_api([model[4]])
        # Check that modules have the same state and registry iff they shared
        # the same API call
        states = [multi_module_api.state(module) for module in model]
        self.assertEqual(states[0], states[1])
        self.assertEqual(states[2], states[3])
        self.assertNotEqual(states[0], states[2])
        self.assertNotEqual(states[0], states[4])
        self.assertNotEqual(states[2], states[4])
        registries = [_get_registry(module) for module in model]
        self.assertEqual(registries[0], registries[1])
        self.assertEqual(registries[2], registries[3])
        self.assertNotEqual(registries[0], registries[2])
        self.assertNotEqual(registries[0], registries[4])
        self.assertNotEqual(registries[2], registries[4])
        # Check that applying an API to a module multiple times errors
        model = nn.Sequential(*[nn.Linear(3, 3) for _ in range(5)])
        multi_module_api([model[0], model[1]])
        with self.assertRaisesRegex(
            AssertionError,
            "Each distinct composable distributed API can only be applied to "
            "a module once. multi_module_api has already been applied to the "
            "following module:",
        ):
            multi_module_api([model[0], model[2]])


if __name__ == "__main__":
    run_tests()