# Owner(s): ["oncall: distributed"]

import copy
import functools
from typing import Dict, List, Optional, Union

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._composable.fsdp._fsdp_collectives import (
    _get_gradient_divide_factors,
)
from torch.testing._internal.common_distributed import (
    requires_nccl_version,
    SaveForwardInputsModel,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import (
    check_sharded_parity,
    FSDPTest,
    FSDPTestMultiThread,
    MLP,
    patch_reduce_scatter,
    reduce_scatter_with_assert,
)
from torch.testing._internal.common_utils import run_tests


class TestFullyShardMixedPrecisionTraining(FSDPTest):
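    """
    Tests numeric parity of mixed precision training with ``fully_shard``
    against a reference model that emulates the low-precision compute and
    gradient reduction manually.
    """
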
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    def _init_models_and_optims(
        self,
        reshard_after_forward: Union[bool, int],
        param_dtype: Optional[torch.dtype],
        reduce_dtype: Optional[torch.dtype],
    ):
        torch.manual_seed(42)
        model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
        ref_model = copy.deepcopy(model).cuda()
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        mp_policy = MixedPrecisionPolicy(
            param_dtype=param_dtype, reduce_dtype=reduce_dtype
        )
        fully_shard_fn = functools.partial(
            fully_shard,
            reshard_after_forward=reshard_after_forward,
            mp_policy=mp_policy,
        )
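        # Shard each MLP individually and then the root module so that each
        # MLP forms its own FSDP parameter group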
        for mlp in model:
            fully_shard_fn(mlp)
        fully_shard_fn(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
        return ref_model, ref_optim, model, optim

    @skip_if_lt_x_gpu(2)
    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
    def test_compute_dtype(self):
        self.run_subtests(
            {
                "param_dtype": [torch.bfloat16, torch.float16],
                "reshard_after_forward": [False, True, 2],
            },
            self._test_compute_dtype,
        )

    def _test_compute_dtype(
        self, param_dtype: torch.dtype, reshard_after_forward: Union[bool, int]
    ):
        ref_model, ref_optim, model, optim = self._init_models_and_optims(
            reshard_after_forward, param_dtype=param_dtype, reduce_dtype=None
        )
        ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
        orig_reduce_scatter = dist.reduce_scatter_tensor

        def assert_fn(output: torch.Tensor):
            self.assertEqual(output.dtype, param_dtype)

        reduce_scatter = functools.partial(
            reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
        )
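        # Query the same gradient pre/post-divide factors that FSDP uses so
        # that the reference model's manual reduction below matches exactly
        # (for fp16, FSDP divides before and after the reduction to avoid
        # overflow/underflow in the low-precision dtype)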
        predivide_factor, postdivide_factor = _get_gradient_divide_factors(
            self.process_group, all_reduce_group=None, reduce_dtype=param_dtype
        )

        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((4, 16), device="cuda", dtype=param_dtype)
        for iter_idx in range(10):
            optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            fsdp_loss = model(inp).sum()
            with patch_reduce_scatter(reduce_scatter):
                fsdp_loss.backward()
            optim.step()

            ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            ref_loss = ref_model_bf16(inp.to(param_dtype)).sum()
            ref_loss.backward()
            for param in ref_model_bf16.parameters():
                # Use reduce-scatter -> all-gather to emulate all-reduce
                # because for world size >= 4, NCCL all-reduce shows numeric
                # differences compared with NCCL reduce-scatter
                if predivide_factor is not None and predivide_factor > 1:
                    param.grad.div_(predivide_factor)
                elif predivide_factor is None:
                    param.grad.div_(self.world_size)
                output = torch.zeros_like(torch.chunk(param.grad, self.world_size)[0])
                dist.reduce_scatter_tensor(output, param.grad)
                dist.all_gather_into_tensor(param.grad, output)
                if postdivide_factor is not None and postdivide_factor > 1:
                    param.grad.div_(postdivide_factor)
            for param_fp32, param_bf16 in zip(
                ref_model.parameters(), ref_model_bf16.parameters()
            ):
                param_fp32.grad = param_bf16.grad.to(param_fp32.dtype)
                param_bf16.grad = None
            ref_optim.step()  # fp32 optimizer step
            for param_fp32, param_bf16 in zip(
                ref_model.parameters(), ref_model_bf16.parameters()
            ):
                param_bf16.detach().copy_(param_fp32)

            self.assertEqual(fsdp_loss, ref_loss)
            check_sharded_parity(self, ref_model, model)

    @skip_if_lt_x_gpu(2)
    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
    def test_reduce_dtype(self):
        self.run_subtests(
            {"reshard_after_forward": [False, True, 2]},
            self._test_reduce_dtype_fp32_reduce,
        )
        self.run_subtests(
            {"reshard_after_forward": [False, True, 2]},
            self._test_reduce_dtype_bf16_reduce,
        )

    def _test_reduce_dtype_fp32_reduce(self, reshard_after_forward: Union[bool, int]):
        param_dtype, reduce_dtype = torch.bfloat16, torch.float32
        ref_model, ref_optim, model, optim = self._init_models_and_optims(
            reshard_after_forward, param_dtype=param_dtype, reduce_dtype=reduce_dtype
        )
        ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
        orig_reduce_scatter = dist.reduce_scatter_tensor

        def assert_fn(output: torch.Tensor):
            self.assertEqual(output.dtype, reduce_dtype)

        reduce_scatter = functools.partial(
            reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
        )
        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((4, 16), device="cuda", dtype=param_dtype)
        for iter_idx in range(10):
            optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            fsdp_loss = model(inp).sum()
            with patch_reduce_scatter(reduce_scatter):
                fsdp_loss.backward()
            optim.step()

            ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            ref_loss = ref_model_bf16(inp.to(param_dtype)).sum()
            ref_loss.backward()
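            # Emulate FSDP's fp32 gradient reduction: upcast the bf16
            # gradients to fp32, all-reduce, and average by world size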
            for param in ref_model_bf16.parameters():
                param.grad.data = param.grad.to(torch.float32)
                dist.all_reduce(param.grad)  # fp32 reduction
                param.grad.div_(self.world_size)
            for param_fp32, param_bf16 in zip(
                ref_model.parameters(), ref_model_bf16.parameters()
            ):
                param_fp32.grad = param_bf16.grad
                param_bf16.grad = None
            ref_optim.step()  # fp32 optimizer step
            for param_fp32, param_bf16 in zip(
                ref_model.parameters(), ref_model_bf16.parameters()
            ):
                param_bf16.detach().copy_(param_fp32)

            self.assertEqual(fsdp_loss, ref_loss)
            check_sharded_parity(self, ref_model, model)

    def _test_reduce_dtype_bf16_reduce(self, reshard_after_forward: Union[bool, int]):
        param_dtype, reduce_dtype = torch.float32, torch.bfloat16
        ref_model, ref_optim, model, optim = self._init_models_and_optims(
            reshard_after_forward, param_dtype=param_dtype, reduce_dtype=reduce_dtype
        )
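        # The reference model stays in fp32; only the gradient reduction below
        # is emulated in bf16 to match FSDP's reduce_dtype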
        group = dist.distributed_c10d._get_default_group()
        orig_reduce_scatter = dist.reduce_scatter_tensor

        def assert_fn(output: torch.Tensor):
            self.assertEqual(output.dtype, reduce_dtype)

        reduce_scatter = functools.partial(
            reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
        )
        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((4, 16), device="cuda", dtype=param_dtype)
        for iter_idx in range(10):
            optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            fsdp_loss = model(inp).sum()
            with patch_reduce_scatter(reduce_scatter):
                fsdp_loss.backward()
            optim.step()

            ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            ref_loss = ref_model(inp).sum()
            ref_loss.backward()
            for param in ref_model.parameters():
                param_grad = param.grad.to(reduce_dtype)
                # Use reduce-scatter -> all-gather to implement all-reduce
                # since for world size >2, bf16 all-reduce and reduce-scatter
                # have numeric differences
                sharded_grad = funcol.reduce_scatter_tensor(
                    param_grad, scatter_dim=0, reduceOp="avg", group=group
                )  # bf16 reduction
                param.grad = funcol.all_gather_tensor(
                    sharded_grad, gather_dim=0, group=group
                ).to(
                    param.dtype
                )  # upcast to fp32
            ref_optim.step()  # fp32 optimizer step

            self.assertEqual(fsdp_loss, ref_loss)
            check_sharded_parity(self, ref_model, model)

    @skip_if_lt_x_gpu(2)
    def test_grad_acc_with_reduce_dtype(self):
        """
        Tests that, when gradient reduce-scatter is disabled for gradient
        accumulation with bf16 compute and fp32 reduction, the unsharded
        gradients are accumulated in fp32.
        """
        self.run_subtests(
            {"reshard_after_forward": [True, False]},
            self._test_grad_acc_with_reduce_dtype,
        )

    def _test_grad_acc_with_reduce_dtype(self, reshard_after_forward: bool):
        torch.manual_seed(42)
        param_dtype, reduce_dtype = (torch.bfloat16, torch.float32)
        mp_policy = MixedPrecisionPolicy(
            param_dtype=param_dtype, reduce_dtype=reduce_dtype
        )
        model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
        # To emulate the mixed precision implementation where forward/backward
        # compute use bf16 and optimizer uses fp32, we maintain both an fp32
        # and a bf16 copy of the reference model
        ref_model = copy.deepcopy(model).cuda()
        ref_model_compute = copy.deepcopy(ref_model).to(param_dtype)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        for mlp in model:
            fully_shard(
                mlp, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
            )
        fully_shard(
            model, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
        )
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        orig_reduce_scatter = dist.reduce_scatter_tensor

        def assert_fn(output: torch.Tensor):
            self.assertEqual(output.dtype, reduce_dtype)

        reduce_scatter = functools.partial(
            reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
        )
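        # The patched reduce-scatter asserts that FSDP reduces gradients in
        # fp32 (the reduce_dtype) even though compute runs in bf16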
        torch.manual_seed(42 + self.rank + 1)
        device = torch.device("cuda")
        # Train on the same input to avoid loss explosion
        num_microbatches = 4
        inp = torch.randn((2 * num_microbatches, 16), device=device, dtype=param_dtype)
        for iter_idx in range(10):
            microbatch_inps = torch.chunk(inp, 4)
            for microbatch_idx in range(num_microbatches):
                is_last_microbatch = microbatch_idx == num_microbatches - 1
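                # Disable FSDP's gradient reduce-scatter on all but the last
                # microbatch so that unsharded gradients accumulate (in fp32,
                # per reduce_dtype) across microbatches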
                model.set_requires_gradient_sync(is_last_microbatch)
                model.set_reshard_after_backward(
                    is_last_microbatch or reshard_after_forward
                )
                losses: List[torch.Tensor] = []
                for _model in (ref_model_compute, model):
                    losses.append(
                        _model(microbatch_inps[microbatch_idx].detach()).sum()
                    )
                    self.assertEqual(losses[-1].dtype, param_dtype)
                    with patch_reduce_scatter(reduce_scatter):
                        losses[-1].backward()
                self.assertEqual(losses[0], losses[1])
                # Manually accumulate gradients into the base reference model
                # from the compute reference model in fp32
                for ref_param, ref_param_compute in zip(
                    ref_model.parameters(), ref_model_compute.parameters()
                ):
                    self.assertTrue(ref_param_compute.grad is not None)
                    self.assertEqual(ref_param.dtype, torch.float32)
                    if ref_param.grad is not None:
                        ref_param.grad += ref_param_compute.grad
                    else:
                        ref_param.grad = ref_param_compute.grad.to(ref_param.dtype)
                    ref_param_compute.grad = None
                # Manually reduce gradients for the reference model on the last
                # microbatch to implement data parallelism
                if is_last_microbatch:
                    for ref_param in ref_model.parameters():
                        self.assertTrue(ref_param.grad is not None)
                        dist.all_reduce(ref_param.grad)
                        ref_param.grad /= self.world_size
            check_sharded_parity(self, ref_model, model)
            ref_optim.step()
            optim.step()
            ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            # Manually copy parameters from the base reference model to the
            # compute reference model to run the optimizer step for the latter
            for ref_param, ref_param_compute in zip(
                ref_model.parameters(), ref_model_compute.parameters()
            ):
                ref_param_compute.detach().copy_(ref_param)


class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
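    """
    Tests the dtype casting behavior of ``MixedPrecisionPolicy`` (e.g.
    ``cast_forward_inputs`` and ``output_dtype``), including interactions
    with norm modules.
    """
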
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(1)
    def test_float16_on_one_submodule(self):
        x = torch.zeros(2, 100, device="cuda")

        # Subtest 1: use fp16 on the second child submodule -- does not require
        # any additional casting logic
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        model = SaveForwardInputsModel(
            forward_inputs,
            cast_forward_inputs=False,
        ).cuda()
        fully_shard(model.c2, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
        fully_shard(model)
        model(x).sum().backward()
        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c2].dtype, torch.float16)

        # Subtest 2: use fp16 on the second child module, where the user module
        # owns the cast
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=True
        ).cuda()
        fully_shard(
            model.c2,
            mp_policy=MixedPrecisionPolicy(
                param_dtype=torch.float16, cast_forward_inputs=False
            ),
        )
        fully_shard(model)
        model(x).sum().backward()
        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c2].dtype, torch.float32)

        # Subtest 3: use fp16 on the first child module and specify its output
        # dtype so that the second child module does not need to cast
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=False
        ).cuda()
        fully_shard(
            model.c1,
            mp_policy=MixedPrecisionPolicy(
                param_dtype=torch.float16, output_dtype=torch.float32
            ),
        )
        fully_shard(model)
        model(x).sum().backward()
        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c1].dtype, torch.float16)
        self.assertEqual(forward_inputs[model.c2].dtype, torch.float32)

    @skip_if_lt_x_gpu(1)
    def test_submodules_with_external_inputs(self):
        self.run_subtests(
            {"enable_submodule_cast": [False, True]},
            self._test_submodules_with_external_inputs,
        )

    def _test_submodules_with_external_inputs(self, enable_submodule_cast: bool):
        class ToyModule(nn.Module):
            def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None:
                super().__init__()
                self.l = nn.Linear(100, 100)
                self.forward_inputs = forward_inputs

            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                self.forward_inputs["l2_input_x"] = x
                self.forward_inputs["l2_input_y"] = y
                return self.l(x)

        class ToyModel(nn.Module):
            def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None:
                super().__init__()
                self.l1 = nn.Linear(100, 100)
                self.l2 = ToyModule(forward_inputs)
                self.forward_inputs = forward_inputs

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.forward_inputs["model_input_x"] = x
                y = torch.ones(
                    2, 100, device="cuda", dtype=torch.float32
                )  # external input
                return self.l2(self.l1(x), y)

        forward_inputs: Dict[str, torch.Tensor] = {}
        model = ToyModel(forward_inputs).cuda()
        x = torch.zeros(2, 100, device="cuda", dtype=torch.float32)
        fully_shard(
            model.l2,
            mp_policy=MixedPrecisionPolicy(
                param_dtype=torch.float16, cast_forward_inputs=enable_submodule_cast
            ),
        )
        fully_shard(model, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
        model(x).sum().backward()

        # If we enable `model.l2` to cast (the default), then `l2_input_y` gets
        # cast to fp16, and if we disable it, then it stays as fp32.
        self.assertEqual(forward_inputs["model_input_x"].dtype, torch.float16)
        self.assertEqual(forward_inputs["l2_input_x"].dtype, torch.float16)
        self.assertEqual(
            forward_inputs["l2_input_y"].dtype,
            torch.float16 if enable_submodule_cast else torch.float32,
        )

    @skip_if_lt_x_gpu(1)
    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
    def test_norm_modules_bf16(self):
        mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
        self._test_norm_modules(mp_policy)

    @skip_if_lt_x_gpu(1)
    def test_norm_modules_fp16(self):
        mp_policy = MixedPrecisionPolicy(param_dtype=torch.float16)
        self._test_norm_modules(mp_policy)

    def _test_norm_modules(self, mp_policy: MixedPrecisionPolicy):
        def inner(model: nn.Module, x: torch.Tensor):
            # Run forward and backward to check for no type mismatch errors
            z = model(x)
            self.assertEqual(z.dtype, mp_policy.param_dtype)
            z.sum().backward()

        # Layer norm
        model = nn.Sequential(nn.Linear(32, 32), nn.LayerNorm(32), nn.Linear(32, 32))
        for module in (model[0], model[1], model[2], model):
            fully_shard(module, mp_policy=mp_policy)
        inner(model, torch.randn((4, 32)))

        # Batch norm 1D
        model = nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32), nn.Linear(32, 32))
        for module in (model[0], model[1], model[2], model):
            fully_shard(module, mp_policy=mp_policy)
        inner(model, torch.randn((4, 32)))

        # Batch norm 2D: error in backward from buffer dtype mismatch
        model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
        for module in (model[0], model[1], model[2], model):
            fully_shard(module, mp_policy=mp_policy)
        with self.assertRaisesRegex(RuntimeError, "Expected running_mean to have type"):
            # Errors in batch norm 2D backward
            inner(model, torch.randn((3, 1, 9, 9)))

        # Batch norm 2D: cast buffers down to lower precision
        model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
        for module in (model[0], model[1], model[2], model):
            fully_shard(module, mp_policy=mp_policy)
        # Casting batch norm buffers to the lower precision allows backward
        model[1].running_mean = model[1].running_mean.to(mp_policy.param_dtype)
        model[1].running_var = model[1].running_var.to(mp_policy.param_dtype)
        inner(model, torch.randn((3, 1, 9, 9)))

        # Batch norm 2D: use special mixed precision policy
        model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
        bn_mp_policy = MixedPrecisionPolicy(output_dtype=mp_policy.param_dtype)
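        # The special policy leaves the batch norm's parameters and buffers in
        # fp32 (no param_dtype) and casts only its forward output to the
        # low-precision dtype so that the following conv sees inputs matching
        # its own compute dtype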
        fully_shard(model[1], mp_policy=bn_mp_policy)
        for module in (model[0], model[2], model):
            fully_shard(module, mp_policy=mp_policy)
        inner(model, torch.randn((3, 1, 9, 9)))


if __name__ == "__main__":
    run_tests()