# Owner(s): ["oncall: distributed"]

import itertools
import sys
from typing import Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import (
    always_wrap_policy as always_wrap,
    enable_wrap,
    ModuleWrapPolicy,
    wrap,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_DEV_DBG_ASAN,
)


_TORCHDISTX_AVAIL = True
try:
    from torchdistx import deferred_init
except ImportError:
    _TORCHDISTX_AVAIL = False


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


def _reset_params_if_meta(is_meta: bool, model: nn.Module):
    # For torchdistX init, we do not need to call reset_parameters() since
    # deferred_init(model).materialize() is equivalent to model().
    if is_meta:
        for module in model.modules():
            # Assume that a module has `reset_parameters()` iff it has directly
            # managed parameters or buffers
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()


class MyLinear(nn.Linear):
    """
    Linear layer with deterministic reset_parameters for testing.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def reset_parameters(self, *args, **kwargs):
        torch.manual_seed(42)
        with torch.no_grad():
            # Use an initialization method that depends on shape
            torch.nn.init.xavier_uniform_(self.weight, 1.0)


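# Module holding a single (3, 3) buffer with a deterministic, shape-dependent
# reset_parameters(), so buffer materialization is exercised alongside parameters.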
class MyBuffer(nn.Module):
    def __init__(self, device: torch.device):
        super().__init__()
        self.buf = torch.nn.Buffer(torch.empty((3, 3), device=device))

    def reset_parameters(self, *args, **kwargs):
        torch.manual_seed(42)
        # Use an initialization method that depends on shape
        torch.nn.init.xavier_uniform_(self.buf, 0.5)


class MyModel(nn.Module):
    def __init__(self, device: torch.device):
        super().__init__()
        self.lin1 = MyLinear(2, 2, bias=False, device=device)
        self.lin2 = MyLinear(2, 2, bias=False, device=device)
        self.buf_mod = MyBuffer(device)

    def forward(self, x):
        return self.lin2(self.lin1(x))


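# Note: wrap() is a no-op unless it is called under an enable_wrap() context
# manager, so the wrap(...) calls below only take effect in the manual-wrap
# test paths.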
class NestedModel(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.lin1 = MyLinear(2, 2, bias=False, device=device)
        self.lin1 = wrap(self.lin1)
        self.lin2 = MyLinear(2, 2, bias=False, device=device)
        self.l3 = MyModel(device=device)
        self.l3 = wrap(self.l3)

    def forward(self, x):
        return self.l3(self.lin2(self.lin1(x)))


def _init_with_reset_params(module: nn.Module):
    """
    Example ``param_init_fn``: materialize modules constructed with
    ``device="meta"`` via ``to_empty()`` and re-initialize them with
    ``reset_parameters()``.
    """
    has_meta_states = any(
        t.is_meta
        for t in itertools.chain(
            module.parameters(recurse=False), module.buffers(recurse=False)
        )
    )
    if has_meta_states:
        device = torch.device("cuda", torch.cuda.current_device())
        module.to_empty(device=device, recurse=False)
        module.reset_parameters()


def _init_with_torchdistX(module: nn.Module):
    """
    Example ``param_init_fn`` that materializes torchdistX deferred-initialized
    modules using ``deferred_init.materialize_module``.
    """
    assert _TORCHDISTX_AVAIL

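    # Do not materialize submodules that have already been wrapped by FSDP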
    def check_fn(k):
        return not isinstance(k, FSDP)

    deferred_init.materialize_module(module, check_fn=check_fn)


class TestFSDPWithMetaDevice(FSDPTest):
    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        return dist.distributed_c10d._get_default_group()

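    # Gather the full (unsharded) parameters on each rank and check that the
    # two FSDP instances have numerically close parameters.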
    def _compare_fsdp(self, fsdp1, fsdp2):
        with FSDP.summon_full_params(fsdp1):
            with FSDP.summon_full_params(fsdp2):
                for p1, p2 in zip(fsdp1.parameters(), fsdp2.parameters()):
                    self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}")

    def _test_simple_model_with_meta_device(self, meta_module_fn, init_fn=None):
        # Create model on meta device and wrap with FSDP.
        model = meta_module_fn()
        is_meta = next(model.parameters()).is_meta
        fsdp_meta = FSDP(
            model,
            auto_wrap_policy=always_wrap,
            param_init_fn=init_fn,
        )

        meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)

        # Check that the parameters match those of an equivalently initialized
        # regular (non-meta device) FSDP model.
        regular = MyModel(device="cuda")
        _reset_params_if_meta(is_meta, regular)
        fsdp_regular = FSDP(regular, auto_wrap_policy=always_wrap)
        regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

        self._compare_fsdp(fsdp_meta, fsdp_regular)
        inp = torch.randn(10, 2, device="cuda")
        fsdp_meta(inp).sum().backward()
        fsdp_regular(inp).sum().backward()
        meta_opt.step()
        regular_opt.step()
        self._compare_fsdp(fsdp_meta, fsdp_regular)

        # Test that meta-device init also works when all submodules are
        # contained in a single FSDP unit.
        model = meta_module_fn()
        fsdp_meta = FSDP(model, param_init_fn=init_fn)
        meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)
        regular = MyModel(device="cuda")
        _reset_params_if_meta(is_meta, regular)
        fsdp_regular = FSDP(regular, auto_wrap_policy=always_wrap)
        regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

        # Run a forward + backward pass + optimizer step
        fsdp_meta(inp).sum().backward()
        fsdp_regular(inp).sum().backward()
        meta_opt.step()
        regular_opt.step()
        self._compare_fsdp(fsdp_meta, fsdp_regular)

    @skip_if_lt_x_gpu(2)
    def test_simple_model_with_meta_device_reset_params(self):
        def meta_module_fn():
            return MyModel(device="meta")

        self._test_simple_model_with_meta_device(
            meta_module_fn, _init_with_reset_params
        )

    @skip_if_lt_x_gpu(2)
    def test_simple_model_with_meta_device_default_init(self):
        def meta_module_fn():
            return MyModel(device="meta")

        self._test_simple_model_with_meta_device(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    def test_simple_model_with_torchdistX_default_init(self):
        def meta_module_fn():
            return deferred_init.deferred_init(MyModel, device="cuda")

        self._test_simple_model_with_meta_device(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    def test_simple_model_with_torchdistX_init_fn(self):
        def meta_module_fn():
            return deferred_init.deferred_init(MyModel, device="cuda")

        self._test_simple_model_with_meta_device(
            meta_module_fn, init_fn=_init_with_torchdistX
        )

    def _test_nested_model_with_meta_device(
        self, auto_wrap, meta_module_fn, init_fn=None
    ):
        if auto_wrap:
            module = meta_module_fn()
            is_meta = (
                next(module.parameters()).is_meta or next(module.buffers()).is_meta
            )
            fsdp_meta = FSDP(
                module,
                auto_wrap_policy=always_wrap,
                param_init_fn=init_fn,
            )
            meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)
            module_regular = NestedModel(device="cuda")
            _reset_params_if_meta(is_meta, module_regular)
            fsdp_regular = FSDP(
                module_regular,
                auto_wrap_policy=always_wrap,
            )
            regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)
        else:
            with enable_wrap(
                wrapper_cls=FSDP,
                param_init_fn=init_fn,
            ):
                module = meta_module_fn()
                is_meta = next(module.parameters()).is_meta
                # Non-FSDP modules are still initialized because they bubble up
                # to be part of a larger FSDP unit.
                fsdp_meta = wrap(module)
                meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)

            # Initialize and reset parameters before wrapping so that
            # reset_parameters() matches up with the meta-device initialization.
            module_regular = NestedModel(device="cuda")
            _reset_params_if_meta(is_meta, module_regular)
            with enable_wrap(wrapper_cls=FSDP):
                module_regular.lin1 = wrap(module_regular.lin1)
                module_regular.l3 = wrap(module_regular.l3)
                fsdp_regular = wrap(module_regular)
                regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

        # Compare the two models before training
        self._compare_fsdp(fsdp_meta, fsdp_regular)
        inp = torch.randn(10, 2, device="cuda")
        fsdp_meta(inp).sum().backward()
        fsdp_regular(inp).sum().backward()
        meta_opt.step()
        regular_opt.step()
        self._compare_fsdp(fsdp_meta, fsdp_regular)

    @skip_if_lt_x_gpu(2)
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_meta_device_reset_params(self, auto_wrap):
        def meta_module_fn():
            return NestedModel(device="meta")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap,
            meta_module_fn=meta_module_fn,
            init_fn=_init_with_reset_params,
        )

    @skip_if_lt_x_gpu(2)
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_meta_device_default_init(self, auto_wrap):
        def meta_module_fn():
            return NestedModel(device="meta")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap,
            meta_module_fn=meta_module_fn,
        )

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_torchdistX_default_init(self, auto_wrap):
        def meta_module_fn():
            return deferred_init.deferred_init(NestedModel, device="cuda")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap, meta_module_fn=meta_module_fn
        )

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_torchdistX_init_fn(self, auto_wrap):
        def meta_module_fn():
            return deferred_init.deferred_init(NestedModel, device="cuda")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap,
            meta_module_fn=meta_module_fn,
            init_fn=_init_with_torchdistX,
        )

    def _test_bad_arg(self, meta_module_fn):
        mod = meta_module_fn()
        with self.assertRaisesRegex(ValueError, "to be callable"):
            FSDP(mod, param_init_fn=42)

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    def test_bad_arg_torchdistx(self):
        def meta_module_fn():
            return deferred_init.deferred_init(NestedModel, "cuda")

        self._test_bad_arg(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    def test_bad_arg_meta(self):
        def meta_module_fn():
            return NestedModel(device="meta")

        self._test_bad_arg(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    def test_meta_device_with_mixed_precision(self):
        """
        Tests meta device initialization with a ``param_init_fn`` when
        specifying mixed precision with ``param_dtype=torch.float32``.
        """

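        # Parameter-holding module that is not an ``nn.Linear`` subclass, so
        # ``_module_init_fn`` below leaves its weight as produced by
        # ``to_empty()``.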
373            def __init__(
374                self, in_dim: int, out_dim: int, device: Union[torch.device, str]
375            ) -> None:
376                super().__init__()
377                self.weight = nn.Parameter(
378                    torch.randn((in_dim, out_dim), device=device)
379                )
380
381            def forward(self, x: torch.Tensor) -> torch.Tensor:
382                return x @ self.weight
383
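        # Mixes a real ``nn.Linear`` with the ``FakeLinear`` above, both
        # constructed on the meta device.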
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin1 = nn.Linear(5, 5, device="meta")
                self.lin2 = FakeLinear(5, 5, device="meta")
                self.relu = nn.ReLU()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.lin2(self.relu(self.lin1(x)))

            def _module_init_fn(self, module: nn.Module):
                if isinstance(module, nn.Linear):
                    torch.nn.init.normal_(module.weight, mean=0.0, std=0.1)
                    if module.bias is not None:
                        torch.nn.init.zeros_(module.bias)

        def _param_init_fn(module: nn.Module) -> None:
            # TODO: `module.to_empty()` is not generally correct for meta
            # device initialization.
            # https://github.com/pytorch/pytorch/issues/90465
            module.to_empty(device=torch.device("cuda"))
            module.apply(model._module_init_fn)

        model = Model()
        # Wrap `lin1` and the top level `model` to create nested FSDP instances
        # where each instance has parameters
        FSDP(
            model,
            auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
            mixed_precision=MixedPrecision(
                param_dtype=torch.float32, reduce_dtype=torch.float16
            ),
            param_init_fn=_param_init_fn,
            device_id=torch.cuda.current_device(),
        )


instantiate_parametrized_tests(TestFSDPWithMetaDevice)

if __name__ == "__main__":
    run_tests()