# Owner(s): ["oncall: distributed"]
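"""Tests ``nn.Module.apply()`` on FSDP-wrapped models."""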

import sys

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    NestedWrappedModule,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestApply(FSDPTest):
    @property
    def world_size(self):
        return 2

    @torch.no_grad()
    def _init_linear_weights(self, m):
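        """Initializes ``nn.Linear`` weights and biases to 1.0 in-place."""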
        if type(m) == nn.Linear:
            m.weight.fill_(1.0)
            m.bias.fill_(1.0)

    def check_weights(self, fsdp, expected_tensor_fn, check):
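        """Compares each ``nn.Linear`` parameter under ``fsdp`` against
        ``expected_tensor_fn(param)`` using the given ``check`` assertion."""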
        with FSDP.summon_full_params(fsdp, recurse=True):
            linear_modules = [
                module for module in fsdp.modules() if type(module) == nn.Linear
            ]
            for module in linear_modules:
                for param in module.parameters():
                    expected = expected_tensor_fn(param)
                    check(param, expected, f"Got {param} but expected {expected}")

    def _check_apply(self, fsdp):
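        """Checks that ``fsdp.apply()`` sets all linear weights to 1.0."""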
        # Assert linear weights are not all 1.0
        self.check_weights(
            fsdp, lambda param: torch.empty_like(param).fill_(1.0), self.assertNotEqual
        )

        fsdp.apply(self._init_linear_weights)

        # Ensure all weights are 1.0
        self.check_weights(
            fsdp, lambda param: torch.empty_like(param).fill_(1.0), self.assertEqual
        )

    @skip_if_lt_x_gpu(2)
    def test_nested_module_apply(self):
        """Tests that ``apply()`` modifies parameter values in-place on a
        non-FSDP-root nested FSDP-wrapped model."""
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_AFTER,
        )
        self._check_apply(nested_wrapped_module)

    @skip_if_lt_x_gpu(2)
    def test_transformer_module_apply(self):
        """Tests that ``apply()`` modifies parameter values in-place on an
        FSDP-wrapped transformer model with shared parameters."""
        transformer = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_AFTER,
        )
        self._check_apply(transformer)

    @skip_if_lt_x_gpu(2)
    def test_apply_in_summon_raises_error(self):
        """Tests that calling ``apply()`` on an FSDP instance inside the
        ``summon_full_params()`` context raises an error."""
        transformer = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_AFTER,
        )
        with transformer.summon_full_params(transformer):
            with self.assertRaisesRegex(ValueError, "expected to be in states"):
                transformer.apply(self._init_linear_weights)


if __name__ == "__main__":
    run_tests()