# Owner(s): ["oncall: distributed"]

import copy
import sys
from typing import Dict

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import checkpoint, fully_shard, replicate
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.fsdp.api import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_dist_composable import (
    CompositeModel,
    CompositeParamModel,
    UnitModule,
)
from torch.testing._internal.common_distributed import (
    SaveForwardInputsModel,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestFSDPCheckpoint(FSDPTest):
    @property
    def world_size(self) -> int:
        return 2

    # TODO: `use_same_inputs_across_ranks` exists for backward compatibility
    # because some test model configs do not have a simple base model to
    # compare against. In those cases, we use the same inputs across ranks so
    # that the averaged gradient equals the local gradient when checking for
    # parity, which means that the gradient reduction itself goes unchecked.
    def _test_parity(
        self,
        base_model: nn.Module,
        test_model: nn.Module,
        inp_size: torch.Size,
        inp_device: torch.device,
        grad_to_none: bool,
        use_same_inputs_across_ranks: bool,
    ):
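        """
        Runs several forward/backward/optimizer steps on ``base_model`` and
        ``test_model`` with the same inputs and checks that their losses match
        at every step.
        """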
        LR = 0.01
        base_optim = torch.optim.Adam(base_model.parameters(), lr=LR)
        test_optim = torch.optim.Adam(test_model.parameters(), lr=LR)

        for _ in range(5):
            if use_same_inputs_across_ranks:
                torch.manual_seed(0)
            x = torch.randn(inp_size, device=inp_device)
            test_loss = test_model(x).sum()
            base_loss = base_model(x).sum()

            self.assertEqual(test_loss, base_loss)

            test_loss.backward()
            test_optim.step()
            test_optim.zero_grad(set_to_none=grad_to_none)

            base_loss.backward()
            base_optim.step()
            base_optim.zero_grad(set_to_none=grad_to_none)

    @skip_if_lt_x_gpu(2)
    def test_wrap_same_submodule(self):
        model = UnitModule(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)

        test_model = copy.deepcopy(model)
        # compose checkpoint and fully_shard
        test_model.seq = checkpoint(test_model.seq)
        test_model.seq = fully_shard(
            test_model.seq,
            policy=ModuleWrapPolicy({nn.Linear}),
        )

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    def _test_checkpoint_fsdp_submodules(self):
        model = CompositeModel(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)

        test_model = copy.deepcopy(model)
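        # Fully shard each UnitModule and checkpoint its inner sequence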
        test_model.u1 = fully_shard(test_model.u1, policy=None)
        test_model.u2 = fully_shard(test_model.u2)

        test_model.u1.seq = checkpoint(test_model.u1.seq)
        test_model.u2.seq = checkpoint(test_model.u2.seq)

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fsdp_submodules_non_reentrant(self):
        self._test_checkpoint_fsdp_submodules()

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fully_shard_cast_forward_inputs(self):
        self.run_subtests(
            {
                "checkpoint_strict_submodule": [False, True],
            },
            self._test_checkpoint_fully_shard_cast_forward_inputs,
        )

    def _test_checkpoint_fully_shard_cast_forward_inputs(
        self, checkpoint_strict_submodule: bool
    ):
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        fp16_mp = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)
        fp32_mp = MixedPrecision(param_dtype=torch.float32, cast_forward_inputs=True)

        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=False
        ).cuda()
        x = torch.zeros(2, 100, device="cuda")

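        # Shard the inner module ``c2`` with fp16 mixed precision and the root
        # with fp32; checkpoint either ``c2`` itself or only its strict
        # submodule ``c2.l`` depending on ``checkpoint_strict_submodule``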
        fully_shard(model.c2, mixed_precision=fp16_mp)
        if checkpoint_strict_submodule:
            checkpoint(model.c2.l)
        else:
            checkpoint(model.c2)
        fully_shard(model, mixed_precision=fp32_mp)

        loss = model(x).sum()
        loss.backward()

        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
        # Notably, check that the recomputed forward preserves the right dtype
        self.assertEqual(forward_inputs[model.c2].dtype, torch.float16)

    @skip_if_lt_x_gpu(2)
    def test_fully_shard_replicate_correct_replicate_params(self):
        model = CompositeParamModel(device=torch.device("cuda"))
        # Shard Linears within UnitModule
        fully_shard(model.u1, policy=ModuleWrapPolicy({nn.Linear}))
        fully_shard(model.u2, policy=ModuleWrapPolicy({nn.Linear}))
        # replicate the rest
        replicate(model)
        # Run fwd + bwd to initialize DDP
        inp = torch.randn(2, 100, device="cuda")
        model(inp).sum().backward()
        # Ensure that the replicated parameter names are as expected, i.e. that
        # the model's immediate parameters and the parameters of its
        # non-UnitModule submodules are replicated
        param_names = replicate.state(model)._param_names
        replicated_modules = [
            (name, mod)
            for (name, mod) in model.named_children()
            if mod not in [model.u1, model.u2]
        ]
        replicated_param_names = [
            f"{module_name}.{n}"
            for module_name, mod in replicated_modules
            for n, _ in mod.named_parameters()
        ]
        replicated_param_names.extend(
            [n for n, _ in model.named_parameters(recurse=False)]
        )
        self.assertEqual(set(param_names), set(replicated_param_names))

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fsdp_submodules_with_param(self):
        model = CompositeParamModel(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)

        test_model = copy.deepcopy(model)
        test_model.u1.seq = checkpoint(test_model.u1.seq)
        test_model.u2.seq = checkpoint(test_model.u2.seq)
        test_model = fully_shard(test_model)

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fsdp_submodules_with_param_no_shard(self):
        model = CompositeParamModel(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)

        test_model = copy.deepcopy(model)
        test_model.u1.seq = checkpoint(test_model.u1.seq)
        test_model.u2.seq = checkpoint(test_model.u2.seq)
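        # NO_SHARD keeps parameters unsharded, so this exercises activation
        # checkpointing under FSDP without parameter sharding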
        test_model = fully_shard(test_model, strategy=ShardingStrategy.NO_SHARD)

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    @skip_if_lt_x_gpu(2)
    def test_composable_fsdp_replicate(self):
        # Verify that the APIs compose as expected, e.g. applying both
        # `fully_shard` and `replicate` to the same module should raise an
        # exception.
        model = CompositeModel(device=torch.device("cuda"))
        fully_shard(model.l1)
        with self.assertRaisesRegex(RuntimeError, "Cannot apply .*replicate"):
            replicate(model.l1)
        replicate(model.l2)  # should not raise

    @skip_if_lt_x_gpu(2)
    def test_fully_shard_replicate_composability(self):
        """
        Tests composing ``fully_shard`` and ``replicate``. To save unit test
        time, we run the different configs in subtests.
        """
        self.run_subtests(
            {
                "config": [
                    "1fm,1r",
                    "1r,1fm",
                    "1r,1fa",
                    "1r1fm,1fm",
                    "1r1fa,1fm",
                    "1fm1fm,1r1r,1fm",
                ]
            },
            self._test_replicate_in_fully_shard,
        )

    def _test_replicate_in_fully_shard(self, config: str):
        """
        To interpret the config, each comma delineates a level in the module
        tree ordered bottom-up; 'r' means ``replicate``; 'f' means
        ``fully_shard``; 'a' means auto wrap; and 'm' means manual wrap.
        """
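        # For example, "1r1fm,1fm" applies ``replicate`` to one submodule and
        # ``fully_shard`` (manual wrap) to another at the bottom level, and
        # then ``fully_shard`` (manual wrap) to the root at the top level.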
        # Set the seed to ensure that all ranks initialize the same model
        torch.manual_seed(0)
        if config == "1fm,1r":
            base_model = CompositeModel(device=torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            fully_shard(test_model.l1)
            replicate(test_model)
        elif config == "1r,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model)
        elif config == "1r,1fa":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model, policy=ModuleWrapPolicy({UnitModule}))
        elif config == "1r1fm,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model.u2)
            fully_shard(test_model)
        elif config == "1r1fa,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model.u2, policy=ModuleWrapPolicy({UnitModule}))
            fully_shard(test_model)
        elif config == "1fm1fm,1r1r,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            fully_shard(test_model.u1.seq)
            fully_shard(test_model.u2.seq)
            replicate(test_model.u1)
            replicate(test_model.u2)
            fully_shard(test_model)
        else:
            raise ValueError(f"Unknown config: {config}")
        # Apply data parallelism to the base model for parity since we apply
        # data parallelism to the test model
        replicate(base_model)

        # Set the seed to ensure that ranks get different input data
        torch.manual_seed(self.rank + 1)
        self._test_parity(
            base_model,
            test_model,
            torch.Size((2, 100)),
            torch.device("cuda"),
            True,
            False,
        )

    @skip_if_lt_x_gpu(2)
    def test_state_dict_fsdp_submodules(self):
        model = CompositeModel(device=torch.device("cuda"))

        full_shard_args = {"strategy": ShardingStrategy.FULL_SHARD}
        no_shard_args = {"strategy": ShardingStrategy.NO_SHARD}

        model.u1 = fully_shard(model.u1, **full_shard_args)
        model.u2 = fully_shard(model.u2, **no_shard_args)

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
        )

        state_dict = model.state_dict()
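        # With SHARDED_STATE_DICT, the FULL_SHARD ``u1`` should yield
        # ShardedTensor entries while the NO_SHARD ``u2`` keeps regular tensors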
        for fqn, tensor in state_dict.items():
            if "u1" in fqn:
                self.assertIsInstance(tensor, ShardedTensor)
            elif "u2" in fqn:
                self.assertIsInstance(tensor, torch.Tensor)
        # Ensure that get_state_dict_type can still correctly get the settings.
        _ = FSDP.get_state_dict_type(model)


instantiate_parametrized_tests(TestFSDPCheckpoint)


if __name__ == "__main__":
    run_tests()