# Owner(s): ["oncall: distributed"]

from copy import deepcopy
from typing import List, Tuple

import torch
import torch.nn as nn
from torch.distributed._composable import _get_registry, contract
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase


class ToyModel(nn.Module):
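    """Small two-branch model with a parameter and a plain tensor buffer, used as the test fixture."""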
    def __init__(self) -> None:
        super().__init__()
        self.seq1 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(2)])
        self.seq2 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(2)])
        self.p = nn.Parameter(torch.randn(10, 10), requires_grad=True)
        self.b = torch.zeros(1)  # buffer

    def forward(self, x, y):
        with torch.no_grad():
            self.b += x.sum() + y.sum()

        return self.p + self.seq1(x) + self.seq2(y)


class TestContract(TestCase):
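    """Tests for the ``@contract`` decorator used by composable distributed APIs:
    hook transparency, FQN preservation, per-module state, the API registry, and
    multi-module application."""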
    @skipIfTorchDynamo("Dynamo does not support the state key")
    def test_add_hooks(self):
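        """A no-op contract API that only registers hooks must leave gradients unchanged."""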
        def forward_pre_hook(
            module: nn.Module, inp: Tuple[torch.Tensor]
        ) -> Tuple[torch.Tensor]:
            return inp

        def forward_hook(
            module: nn.Module, inp: Tuple[torch.Tensor], out: torch.Tensor
        ) -> torch.Tensor:
            return out

        def backward_pre_hook(
            module: nn.Module, grad_output: torch.Tensor
        ) -> torch.Tensor:
            return grad_output

        def backward_hook(
            module: nn.Module,
            grad_input: Tuple[torch.Tensor],
            grad_output: torch.Tensor,
        ) -> Tuple[torch.Tensor]:
            return grad_input

        @contract()
        def noop_api(module: nn.Module) -> nn.Module:
            module.register_forward_pre_hook(forward_pre_hook)
            module.register_forward_hook(forward_hook)
            module.register_full_backward_pre_hook(backward_pre_hook)
            module.register_full_backward_hook(backward_hook)
            return module

        model = ToyModel()
        model_with_hooks = deepcopy(model)
        noop_api(model_with_hooks.seq1)
        noop_api(model_with_hooks.seq2)

        x, y = torch.randn(10, 10), torch.randn(10, 10)
        model(x, y).sum().backward()
        model_with_hooks(x, y).sum().backward()

        for p1, p2 in zip(model.parameters(), model_with_hooks.parameters()):
            self.assertEqual(p1, p2)

    @skipIfTorchDynamo("Dynamo does not support the state key")
    def test_modify_fqn(self):
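        """Contract APIs may not change parameter FQNs, e.g. by wrapping the module."""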
        class ModelWrapper(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, x):
                return self.module(x)

        @contract()
        def wrap_module(module: nn.Module) -> nn.Module:
            return ModelWrapper(module)

        model = ToyModel()

        regex = "Checking parameters: Composable distributed API implementations cannot modify FQNs."
        with self.assertRaisesRegex(RuntimeError, regex):
            wrap_module(model.seq1)

    @skipIfTorchDynamo("Dynamo does not support the state key")
    def test_state(self):
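        """``api.state(module)`` holds per-module state that hooks can read and update."""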
        def check_and_update_state_hook(
            module: nn.Module, inp: Tuple[torch.Tensor]
        ) -> Tuple[torch.Tensor]:
            self.assertEqual(api.state(module).dummy_state, 7)
            api.state(module).dummy_state = 8
            return inp

        # FIXME: the circular reference looks a bit weird. Should we make .state
        # a top-level API instead of one attached to the contract API?
        @contract()
        def api(module: nn.Module) -> nn.Module:
            api.state(module).dummy_state = 7
            module.register_forward_pre_hook(check_and_update_state_hook)
            return module

        model = ToyModel()
        api(model.seq1)

        self.assertEqual(api.state(model.seq1).dummy_state, 7)
        model(torch.zeros(10, 10), torch.zeros(10, 10))
        self.assertEqual(api.state(model.seq1).dummy_state, 8)

    @skipIfTorchDynamo("Dynamo does not support the state key")
    def test_registry(self):
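        """``_get_registry`` records which APIs were applied to a module; reapplying one errors."""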
        @contract()
        def api1(module: nn.Module) -> nn.Module:
            return module

        @contract()
        def api2(module: nn.Module) -> nn.Module:
            return module

        model = ToyModel()
        model = api1(model)
        self.assertEqual(1, len(_get_registry(model)))
        self.assertTrue("api1" in _get_registry(model))
        model = api2(model)
        self.assertEqual(2, len(_get_registry(model)))
        self.assertEqual(list(_get_registry(model).keys()), ["api1", "api2"])
        self.assertEqual(None, _get_registry(model.seq1))
        self.assertEqual(None, _get_registry(model.seq2))

        with self.assertRaisesRegex(AssertionError, "api1 has already been applied"):
            model = api1(model)

    @skipIfTorchDynamo("Dynamo does not support the state key")
    def test_multi_module_api(self):
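        """An API applied to a list of modules shares state and registry across them,
        and a module cannot be claimed by the same API twice."""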
        @contract()
        def multi_module_api(modules: List[nn.Module]) -> List[nn.Module]:
            return modules

        model = nn.Sequential(*[nn.Linear(3, 3) for _ in range(5)])
        multi_module_api([model[0], model[1]])
        multi_module_api([model[2], model[3]])
        multi_module_api([model[4]])
        # Check that modules have the same state and registry iff they shared
        # the same API call
        states = [multi_module_api.state(module) for module in model]
        self.assertEqual(states[0], states[1])
        self.assertEqual(states[2], states[3])
        self.assertNotEqual(states[0], states[2])
        self.assertNotEqual(states[0], states[4])
        self.assertNotEqual(states[2], states[4])
        registries = [_get_registry(module) for module in model]
        self.assertEqual(registries[0], registries[1])
        self.assertEqual(registries[2], registries[3])
        self.assertNotEqual(registries[0], registries[2])
        self.assertNotEqual(registries[0], registries[4])
        self.assertNotEqual(registries[2], registries[4])
        # Check that applying an API to a module multiple times errors
        model = nn.Sequential(*[nn.Linear(3, 3) for _ in range(5)])
        multi_module_api([model[0], model[1]])
        with self.assertRaisesRegex(
            AssertionError,
            "Each distinct composable distributed API can only be applied to "
            r"a module once. multi_module_api has already been applied to the "
            "following module:",
        ):
            multi_module_api([model[0], model[2]])


if __name__ == "__main__":
    run_tests()