# Owner(s): ["oncall: distributed"]

import contextlib
import sys
from copy import deepcopy
from functools import partial

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    offload_wrapper,
)
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import _maybe_wrap_fsdp, FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)
from torch.utils.checkpoint import checkpoint


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

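# Global flag set by the patched save_on_cpu below; tests reset it and assert on
# it to confirm that activation offloading to CPU was actually triggered.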
_save_on_cpu_called = False


def get_patched_save_on_cpu():
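    # Returns a wrapper that flips the global _save_on_cpu_called flag before
    # delegating to the original save_on_cpu.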
    orig_save_on_cpu = (
        torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu
    )

    def patched_save_on_cpu(*args, **kwargs):
        global _save_on_cpu_called
        _save_on_cpu_called = True
        return orig_save_on_cpu(*args, **kwargs)

    return patched_save_on_cpu


@contextlib.contextmanager
def patch_save_on_cpu(new_save_on_cpu):
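    # Temporarily replaces save_on_cpu in the checkpoint_wrapper module with
    # new_save_on_cpu, restoring the original on exit.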
    orig_save_on_cpu = (
        torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu
    )
    torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu = (
        new_save_on_cpu
    )
    try:
        yield
    finally:
        torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu = (
            orig_save_on_cpu
        )


class TestFSDPCheckpoint(FSDPTest):
    class SequentialModule(nn.Module):
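        # Three Linear(3, 3) layers, each optionally wrapped with activation
        # checkpointing (or offloading) and/or FSDP.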
        def __init__(
            self,
            checkpoint_layer=False,
            offload_activations=False,
            wrap_fsdp=False,
            *fsdp_args,
            **fsdp_kwargs,
        ):
            torch.manual_seed(0)
            torch.cuda.manual_seed(0)
            super().__init__()
            l1 = nn.Linear(3, 3).cuda()
            l2 = nn.Linear(3, 3).cuda()
            l3 = nn.Linear(3, 3).cuda()

            if checkpoint_layer:
                if offload_activations:
                    ckpt_wrapper = offload_wrapper
                else:
                    ckpt_wrapper = checkpoint_wrapper

                l1 = ckpt_wrapper(l1)
                l2 = ckpt_wrapper(l2)
                l3 = ckpt_wrapper(l3)

            fsdp_wrapper = partial(
                _maybe_wrap_fsdp, *fsdp_args, wrap_fsdp=wrap_fsdp, **fsdp_kwargs
            )
            self.ffn = nn.Sequential(
                fsdp_wrapper(l1),
                fsdp_wrapper(l2),
                fsdp_wrapper(l3),
            )

        def forward(self, x):
            return self.ffn(x)

    def _verify_parity(self, losses, outputs, models):
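        # All models should produce the same loss, output, and gradients as the
        # first (reference) model.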
        assert losses
        assert outputs
        assert models

        for l, o in zip(losses[1:], outputs[1:]):
            self.assertEqual(losses[0], l)
            self.assertEqual(outputs[0], o)

        # Verify grads
        ref_model = models[0]
        ref_grads = [p.grad for p in ref_model.parameters()]
        for m in models[1:]:
            grads = [p.grad for p in m.parameters()]
            for ref_g, g in zip(ref_grads, grads):
                self.assertEqual(ref_g, g)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("offload_activations", [True, False])
    @parametrize("use_orig_params", [False, True])
    def test_checkpoint_fsdp_wrapping(
        self,
        cpu_offload: CPUOffload,
        offload_activations: bool,
        use_orig_params: bool,
    ):
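        # Compare checkpoint-wrapping outside FSDP, checkpointing inside FSDP,
        # and a plain FSDP baseline for loss/output/grad parity.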
        # Test checkpoint(FSDP(layer1), FSDP(layer2), ....)
        if offload_activations:
            wrapper_to_use = offload_wrapper
        else:
            wrapper_to_use = checkpoint_wrapper

        fsdp_kwargs = {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params}
        ckpt_sequential_wrapped_fsdp = wrapper_to_use(
            TestFSDPCheckpoint.SequentialModule(
                wrap_fsdp=True,
                **fsdp_kwargs,
            ),
        )
        # Test FSDP(checkpoint(layer1)), FSDP(checkpoint(layer2)), ....
        inner_ckpt = TestFSDPCheckpoint.SequentialModule(
            checkpoint_layer=True,
            offload_activations=offload_activations,
            wrap_fsdp=True,
            **fsdp_kwargs,
        )

        baseline = TestFSDPCheckpoint.SequentialModule(
            wrap_fsdp=True,
            **fsdp_kwargs,
        )

        # note that reentrant-based checkpointing requires inputs to have grad
        # flag set.
        inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)

        global _save_on_cpu_called
        models = [ckpt_sequential_wrapped_fsdp, inner_ckpt, baseline]
        with patch_save_on_cpu(get_patched_save_on_cpu()):
            for i in range(2):
                losses = []
                outputs = []
                for m in models:
                    check_offload = m != baseline and i == 0 and offload_activations
                    if check_offload:
                        self.assertFalse(_save_on_cpu_called)
                    out = m(inp)
                    if check_offload:
                        self.assertTrue(_save_on_cpu_called)
                        _save_on_cpu_called = False
                    loss = out.sum()
                    loss.backward()
                    losses.append(loss)
                    outputs.append(out)

                self._verify_parity(losses, outputs, models)

        dist.barrier()

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("offload_activations", [True, False])
    @parametrize("use_orig_params", [False, True])
    def test_basic_checkpoint_end_to_end(
        self,
        cpu_offload: CPUOffload,
        offload_activations: bool,
        use_orig_params: bool,
    ):
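        # End-to-end parity check across FSDP without checkpointing,
        # checkpoint-wrapped FSDP, FSDP-wrapped checkpointing, and manual
        # torch.utils.checkpoint calls.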
        fsdp_kwargs = {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params}
        global _save_on_cpu_called
        with patch_save_on_cpu(get_patched_save_on_cpu()):
            seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device())
            # Runs FSDP with no checkpointing
            fsdp_only_seq = FSDP(deepcopy(seq), **fsdp_kwargs)
            # Runs checkpoint-wrapped FSDP
            if offload_activations:
                wrapper_to_use = offload_wrapper
            else:
                wrapper_to_use = checkpoint_wrapper

            checkpointed_fsdp = wrapper_to_use(
                FSDP(deepcopy(seq), **fsdp_kwargs),
            )
            # Runs FSDP-wrapped checkpointed module
            fsdp_wrapped_checkpoint = FSDP(
                wrapper_to_use(deepcopy(seq)),
                **fsdp_kwargs,
            )
            # Runs FSDP with manual calls to checkpoint.
            fsdp_call_checkpoint = FSDP(deepcopy(seq), **fsdp_kwargs)
            # note that reentrant-based checkpointing requires inputs to have grad
            # flag set.

            inp = torch.randn(
                10, 3, device=torch.cuda.current_device(), requires_grad=True
            )

            models = [
                fsdp_only_seq,
                checkpointed_fsdp,
                fsdp_wrapped_checkpoint,
                fsdp_call_checkpoint,
            ]
            # Ensure _save_on_cpu is not yet called
            self.assertFalse(_save_on_cpu_called)
            for i in range(6):
                losses = []
                outputs = []
                for m in models:
                    check_offload = (
                        m != fsdp_only_seq and i == 0 and offload_activations
                    )
                    if m == fsdp_call_checkpoint:
                        # _save_on_cpu should not be called yet
                        self.assertFalse(_save_on_cpu_called)
                        offload_ctx = (
                            get_patched_save_on_cpu()(pin_memory=True)
                            if offload_activations
                            else contextlib.nullcontext()
                        )
                        with offload_ctx:
                            out = checkpoint(m, inp, use_reentrant=True)
                    else:
                        # _save_on_cpu should not be called yet
                        self.assertFalse(_save_on_cpu_called)
                        out = m(inp)

                    if check_offload:
                        self.assertTrue(_save_on_cpu_called)
                    loss = out.sum()
                    loss.backward()
                    losses.append(loss)
                    outputs.append(out)
                    _save_on_cpu_called = False

                self._verify_parity(losses, outputs, models)

        dist.barrier()


instantiate_parametrized_tests(TestFSDPCheckpoint)


class CheckpointModule(nn.Module):
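    # Four Linear(100, 100) layers run either directly or under
    # torch.utils.checkpoint.checkpoint, depending on the flag.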
    def __init__(self, checkpoint: bool = False, use_reentrant: bool = True):
        super().__init__()
        self.seq = nn.Sequential(*[nn.Linear(100, 100) for _ in range(4)])
        self.checkpoint = checkpoint
        self.use_reentrant = use_reentrant

    def forward(self, x):
        return (
            checkpoint(self.seq, x, use_reentrant=self.use_reentrant)
            if self.checkpoint
            else self.seq(x)
        )


class ModelWithCheckpointSubmodule(nn.Module):
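    # Two CheckpointModule submodules sandwiched between linear layers.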
    def __init__(self, checkpoint: bool = False, use_reentrant: bool = True):
        super().__init__()
        self.l1 = nn.Linear(100, 100)
        self.s1 = CheckpointModule(checkpoint, use_reentrant)
        self.s2 = CheckpointModule(checkpoint, use_reentrant)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(100, 100)

    def forward(self, x):
        return self.l2(self.relu(self.s2(self.s1(self.l1(x)))))


class TestModel(nn.Module):
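    # Top-level model containing two ModelWithCheckpointSubmodule blocks.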
    def __init__(self, checkpoint: bool = False, use_reentrant: bool = True):
        super().__init__()
        self.l1 = nn.Linear(100, 100)
        self.relu = nn.ReLU()
        self.checkpoint1 = ModelWithCheckpointSubmodule(checkpoint, use_reentrant)
        self.checkpoint2 = ModelWithCheckpointSubmodule(checkpoint, use_reentrant)
        self.l2 = nn.Linear(100, 100)

    def forward(self, x):
        return self.l2(self.relu(self.checkpoint2(self.checkpoint1(self.l1(x)))))


class TestFSDPCheckpointSubmodule(FSDPTest):
    # TODO: grad value checks occasionally fail when use_reentrant = True
    @skip_if_lt_x_gpu(2)
    @parametrize("use_reentrant", [False])
    def test_checkpoint_submodule(self, use_reentrant: bool):
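        # Enable checkpointing only on submodules of the FSDP-wrapped modules and
        # verify gradient parity with the non-checkpointed model.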
        model = TestModel(use_reentrant=use_reentrant).cuda()
        model_ac = deepcopy(model)

        for _, m in model_ac.named_modules():
            if isinstance(m, CheckpointModule):
                m.checkpoint = True

        self.assertTrue(model_ac.checkpoint1.s1.checkpoint)
        self.assertTrue(model_ac.checkpoint2.s2.checkpoint)

        fsdp_kwargs = {
            "device_id": torch.cuda.current_device(),
            "sharding_strategy": ShardingStrategy.NO_SHARD,
        }

        # Wrap the non-checkpointing model's submodules with FSDP
        model.checkpoint1 = FSDP(module=model.checkpoint1, **fsdp_kwargs)
        model.checkpoint2 = FSDP(module=model.checkpoint2, **fsdp_kwargs)

        # Wrap the checkpointing model's submodules with FSDP
        model_ac.checkpoint1 = FSDP(module=model_ac.checkpoint1, **fsdp_kwargs)
        model_ac.checkpoint2 = FSDP(module=model_ac.checkpoint2, **fsdp_kwargs)

        x = torch.randn(2, 100, device="cuda")

        model(x).sum().backward()
        model_ac(x).sum().backward()

        for (n1, p1), (n2, p2) in zip(
            model.named_parameters(), model_ac.named_parameters()
        ):
            self.assertEqual(n1, n2)
            self.assertTrue(p1.grad.allclose(p2.grad))


instantiate_parametrized_tests(TestFSDPCheckpointSubmodule)


if __name__ == "__main__":
    run_tests()