# Owner(s): ["oncall: distributed"]

import copy
import functools
import itertools
from typing import List, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._composable import checkpoint, replicate
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._composable.fsdp._fsdp_param_group import (
    RegisterPostBackwardFunction,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    check_sharded_parity,
    FSDPTest,
    MLP,
    patch_reduce_scatter,
    patch_register_post_backward_hook_backward,
    reduce_scatter_with_assert,
)
from torch.testing._internal.common_utils import run_tests


class TestFullyShardFrozen(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    @skip_if_lt_x_gpu(2)
    def test_train_mixed_requires_grad_per_group(self):
        """
        Tests training parity with DDP when mixing frozen and non-frozen
        parameters in the same FSDP communication group. This checks that
        the reduce-scatters reduce the expected numel and that they are issued
        from the custom autograd function's backward (i.e., that they are not
        delayed until the end of backward).
        """
        self.run_subtests(
            {
                "reshard_after_forward": [False, True, 2],
                "use_activation_checkpointing": [False, True],
                "freeze_after_init": [False, True],
            },
            self._test_train_mixed_requires_grad_per_group,
        )

    def _test_train_mixed_requires_grad_per_group(
        self,
        reshard_after_forward: Union[bool, int],
        use_activation_checkpointing: bool,
        freeze_after_init: bool,
    ):
        torch.manual_seed(42)
        num_mlps, lin_dim = (3, 32)
        model = nn.Sequential(
            *[MLP(lin_dim, torch.device("cpu")) for _ in range(num_mlps)]
        )
        # Train only the biases (e.g., as in BitFit)
        if not freeze_after_init:
            for param_name, param in model.named_parameters():
                if "bias" not in param_name:
                    param.requires_grad_(False)
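        # DDP only needs `find_unused_parameters=True` when freezing happens
        # after construction, since it registers all parameters that require
        # grad at construction time and would otherwise expect their gradients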
        ref_model = replicate(
            copy.deepcopy(model).cuda(),
            device_ids=[self.rank],
            find_unused_parameters=freeze_after_init,
        )
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
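        # Apply FSDP per MLP so that each communication group mixes frozen
        # weights and non-frozen biases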
        for mlp in model:
            if use_activation_checkpointing:
                checkpoint(mlp)
            fully_shard(mlp, reshard_after_forward=reshard_after_forward)
        fully_shard(model, reshard_after_forward=reshard_after_forward)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        orig_reduce_scatter = dist.reduce_scatter_tensor
        if freeze_after_init:
            for param_name, param in itertools.chain(
                model.named_parameters(), ref_model.named_parameters()
            ):
                if "bias" not in param_name:
                    param.requires_grad_(False)
        for mlp in model:
            assert isinstance(mlp, MLP), (
                "The reduce-scatter numel check assumes that the model "
                f"consists only of MLP modules, but got {type(mlp)}"
            )
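        # Since only the biases require grad, each reduce-scatter should only
        # communicate the bias gradients; compute the expected sharded numel
        # from the first MLP (all MLPs are identical)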
        expected_numel = sum(
            p._local_tensor.numel()
            for n, p in model[0].named_parameters()
            if "bias" in n
        )

        def assert_fn(output: torch.Tensor):
            self.assertEqual(output.numel(), expected_numel)

        reduce_scatter = functools.partial(
            reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
        )
        orig_backward = RegisterPostBackwardFunction.backward
        backward_count = 0

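        # Wrap `RegisterPostBackwardFunction.backward` to count how many times
        # the post-backward hooks run through the custom autograd function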
        def backward_with_count(*args, **kwargs):
            nonlocal backward_count
            backward_count += 1
            return orig_backward(*args, **kwargs)

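        # Seed differently per rank so that each rank gets different inputs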
        torch.manual_seed(42 + self.rank + 1)
        device = torch.device("cuda")
        with patch_reduce_scatter(
            reduce_scatter
        ), patch_register_post_backward_hook_backward(backward_with_count):
            for iter_idx in range(10):
                inp = torch.randn((8, lin_dim), device=device)
                losses: List[torch.Tensor] = []
                for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                    _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                    losses.append(_model(inp).sum())
                    losses[-1].backward()
                    _optim.step()
                check_sharded_parity(self, ref_model, model)
                self.assertEqual(losses[0], losses[1])
                # Check that the post-backward hooks ran through the autograd
                # backward, not the final callback (except possibly that of the
                # first MLP, which does not have an input that requires grad)
                self.assertTrue(backward_count >= num_mlps - 1)

    @skip_if_lt_x_gpu(2)
    def test_train_mixed_requires_grad_across_groups(self):
        """
        Tests training parity with DDP when mixing frozen and non-frozen
        parameters across different FSDP communication groups, including
        the case where some parameters are unfrozen partway through training.
        """
        self.run_subtests(
            {
                "reshard_after_forward": [False, True, 2],
                "unfreeze_params": [False, True],
            },
            self._test_train_mixed_requires_grad_across_groups,
        )

    def _test_train_mixed_requires_grad_across_groups(
        self,
        reshard_after_forward: Union[bool, int],
        unfreeze_params: bool,
    ):
        torch.manual_seed(42)
        num_linears, lin_dim = (6, 32)
        modules: List[nn.Module] = []
        for _ in range(num_linears):
            modules += [nn.Linear(lin_dim, lin_dim), nn.ReLU()]
        model = nn.Sequential(*modules)
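        # `find_unused_parameters=True` is needed since some parameters are
        # frozen after DDP construction and hence never receive gradients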
        ref_model = replicate(
            copy.deepcopy(model).cuda(),
            device_ids=[self.rank],
            find_unused_parameters=True,
        )
        for module in model.modules():
            if isinstance(module, nn.Linear):
                fully_shard(module, reshard_after_forward=reshard_after_forward)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        orig_backward = RegisterPostBackwardFunction.backward
        backward_count = 0

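        # `seq` alternates `nn.Linear` and `nn.ReLU` modules, so the i-th
        # linear is at index `2 * i`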
        def _set_requires_grad(seq: nn.Module, requires_grad: bool):
            for i in range(num_linears):
                # Interleave frozen -> non-frozen -> ... linears
                if i % 2 == 0:
                    for param in seq[i * 2].parameters():
                        param.requires_grad_(requires_grad)

        def backward_with_count(*args, **kwargs):
            nonlocal backward_count
            backward_count += 1
            return orig_backward(*args, **kwargs)

        _set_requires_grad(model, False)
        _set_requires_grad(ref_model, False)
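        # Run one iteration under `torch.no_grad()` to exercise forward-only
        # execution in between training iterations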
        num_iters, no_grad_iter_idx = (3, 1)
        torch.manual_seed(42 + self.rank)
        inp = torch.randn((8, lin_dim), device="cuda")
        with patch_register_post_backward_hook_backward(backward_with_count):
            for iter_idx in range(num_iters):
                losses: List[torch.Tensor] = []
                for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                    # Unfreeze the parameters on the last step to emulate some
                    # kinds of fine-tuning
                    if unfreeze_params and iter_idx == num_iters - 1:
                        _set_requires_grad(model, True)
                    if iter_idx == no_grad_iter_idx:
                        with torch.no_grad():
                            losses.append(_model(inp).sum())
                    else:
                        losses.append(_model(inp).sum())
                        losses[-1].backward()
                        _optim.step()
                        _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            self.assertEqual(losses[0], losses[1])
            # Check that the post-backward hooks ran through the autograd
            # backward, not the final callback (except possibly that of the
            # first linear, which does not have an input that requires grad)
            self.assertTrue(backward_count >= num_linears - 1)

    @skip_if_lt_x_gpu(2)
    def test_multi_forward_mixed_requires_grad(self):
        """
        Tests training parity with DDP when trainable and frozen modules
        participate multiple times in forward.
        """
        self.run_subtests(
            {"reshard_after_forward": [True, False, 2]},
            self._test_multi_forward_mixed_requires_grad,
        )

    def _test_multi_forward_mixed_requires_grad(
        self,
        reshard_after_forward: Union[bool, int],
    ):
        class MultiForwardModule(nn.Module):
            def __init__(self, device: torch.device):
                super().__init__()
                self.layer_0 = nn.Linear(5, 5, device=device)
                self.layer_no_grad = nn.Linear(5, 5, device=device)
                self.layer_with_grad = nn.Linear(5, 5, device=device)
                self.layer_no_grad.requires_grad_(False)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                x = self.layer_0(x)
                for _ in range(3):
                    x = self.layer_no_grad(F.relu(self.layer_with_grad(x)))
                    # Make sure that calling the same layer multiple times
                    # works regardless of whether gradients are enabled
                    with torch.no_grad():
                        x += F.relu(self.layer_with_grad(x))
                return x

        torch.manual_seed(42)
        model = MultiForwardModule(torch.device("cpu"))
        ref_model = replicate(copy.deepcopy(model).cuda(), device_ids=[self.rank])
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
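        # Shard each linear separately so that the frozen and non-frozen
        # layers are in different FSDP communication groups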
        for module in model.modules():
            if isinstance(module, nn.Linear):
                fully_shard(module, reshard_after_forward=reshard_after_forward)
        fully_shard(model, reshard_after_forward=reshard_after_forward)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
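        # Check loss parity against the DDP baseline on every iteration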
        for iter_idx in range(10):
            inp = torch.randn((8, 5), device="cuda")
            losses: List[torch.Tensor] = []
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                losses.append(_model(inp).sum())
                losses[-1].backward()
                _optim.step()
            self.assertEqual(losses[0], losses[1])


if __name__ == "__main__":
    run_tests()