# Owner(s): ["oncall: distributed"]

import time
from dataclasses import dataclass, field
from enum import auto, Enum
from functools import partial
from io import BytesIO
from typing import Any, Dict, List

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as DCP
import torch.distributed.checkpoint.state_dict_saver as saver
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.checkpoint.state_dict import (
    _patch_model_state_dict,
    _patch_optimizer_state_dict,
    get_model_state_dict,
    get_optimizer_state_dict,
    get_state_dict,
    set_state_dict,
)
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys
from torch.distributed.checkpoint.utils import CheckpointException
from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
from torch.testing._internal.distributed.common_state_dict import VerifyStateDictMixin


# Simple and boring model
class TestDummyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Linear(8, 16)
        self.net2 = nn.Linear(16, 32)
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Linear(64, 8)

    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        x = F.relu(self.net4(x))
        return x

    def get_input(self):
        return torch.rand(8, 8, device="cuda")

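# A plain object implementing the Stateful protocol (state_dict /
# load_state_dict), used to verify that DCP can save and load arbitrary
# non-module, non-optimizer objects.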
class TestStatefulObj:
    def __init__(self) -> None:
        self.data = torch.rand(10, 10, device="cuda")

    def state_dict(self):
        return {"data": self.data}

    def load_state_dict(self, state_dict):
        self.data = state_dict["data"]

    def __eq__(self, other):
        return torch.equal(self.data, other.data)

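# Parallelization strategies exercised by the parametrized tests below.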
class ModelType(Enum):
    FSDP = auto()
    HSDP = auto()
    FSDP_TP = auto()
    DDP = auto()
    NONE = auto()  # no parallelization

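# Minimal non-module training state. Scalar fields are stored as tensors and
# the loss history is serialized into a BytesIO buffer, both of which DCP can
# round-trip.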
@dataclass
class TestTrainState:
    step: int = 0
    current_loss: float = -1
    losses: List[float] = field(default_factory=list)

    def state_dict(self) -> Dict[str, Any]:
        loss_bytes = BytesIO()
        torch.save(self.losses, loss_bytes)
        return {
            "step": torch.tensor(self.step, dtype=torch.int32),
            "current_loss": torch.tensor(self.current_loss, dtype=torch.float32),
            "losses": loss_bytes,
        }

    def load_state_dict(self, state_dict) -> None:
        self.step = state_dict["step"].item()
        self.current_loss = state_dict["current_loss"].item()
        state_dict["losses"].seek(0)
        self.losses = torch.load(state_dict["losses"])

    def __eq__(self, other):
        return (
            self.step == other.step
            and self.current_loss == other.current_loss
            and self.losses == other.losses
        )

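# Runs a few forward/backward/optimizer steps and records a synthetic
# TestTrainState so checkpoints can be compared across training runs.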
def _train(model, optim, train_steps=1):
    torch.manual_seed(0)
    loss = None

    train_state = TestTrainState()

    for _ in range(train_steps):
        loss = model(model.get_input()).sum()
        loss.backward()

        # In real training the loss is usually synced across dp ranks;
        # here we only simulate it for testing purposes.
        train_state.step += 1
        train_state.current_loss = torch.rand(1).item()
        train_state.losses.append(train_state.current_loss)

        optim.step()
        optim.zero_grad()

    return loss, train_state

class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
    @property
    def backend(self):
        return "cpu:gloo,cuda:nccl"

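    # Wraps TestDummyModel in the requested parallelism (FSDP, HSDP, FSDP + TP,
    # DDP, or none), optionally compiles it, and patches the model/optimizer
    # state_dict hooks so plain state_dict()/load_state_dict() calls route
    # through torch.distributed.checkpoint.state_dict.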
    def _create_model(self, compile, model_type, state_dict_options=None):
        dummy_model = TestDummyModel().cuda()

        assert model_type in ModelType, f"{model_type} is not supported."
        if model_type == ModelType.FSDP:
            device_mesh = init_device_mesh(self.device_type, (self.world_size,))
            model = FSDP(
                dummy_model,
                device_mesh=device_mesh,
                use_orig_params=True,
            )
        elif model_type == ModelType.HSDP:
            device_mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
            model = FSDP(
                dummy_model,
                device_mesh=device_mesh,
                use_orig_params=True,
                sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            )
        elif model_type == ModelType.FSDP_TP:
            mesh_2d = init_device_mesh(
                self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
            )
            tp_mesh = mesh_2d["tp"]
            dp_mesh = mesh_2d["dp"]
            parallelize_plan = {
                "net1": ColwiseParallel(),
                "net2": RowwiseParallel(),
            }
            model = parallelize_module(dummy_model, tp_mesh, parallelize_plan)
            model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)
        elif model_type == ModelType.DDP:
            model = DistributedDataParallel(dummy_model)
            model.get_input = partial(TestDummyModel.get_input, model)
        else:
            model = dummy_model

        if compile:
            # TODO: enable dynamic=True when dynamic shape support is enabled.
            # model = torch.compile(model)
            model = torch.compile(model, dynamic=False)

        optim = self._optim(model)
        if model_type is not ModelType.NONE:
            _patch_model_state_dict(model, options=state_dict_options)
            _patch_optimizer_state_dict(
                model, optimizers=optim, options=state_dict_options
            )

        return model, optim

    def _optim(self, model):
        return torch.optim.Adam(model.parameters(), lr=0.1)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    @parametrize("compile", [True, False])
    # TODO: Previously, PairwiseParallel did not shard properly, which let the
    # ModelType.FSDP_TP test pass where it should have failed. Disable the failing
    # test temporarily to unblock the deprecation of PairwiseParallel.
    @parametrize("model_type", [ModelType.FSDP, ModelType.HSDP, ModelType.DDP])
    def test_e2e(self, compile, model_type):
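        """Saves and reloads the model, optimizer, and custom stateful objects
        through DCP for the given parallelism mode and compile setting."""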
        self._run_e2e_test(compile, model_type)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    @parametrize("cache_staged_state_dict", [False, True])
    def test_e2e_async_cached(self, cache_staged_state_dict):
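        """Runs the e2e flow with async_save, with and without caching the
        staged state dict in the FileSystemWriter."""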
        self._run_e2e_test(
            compile=False,
            model_type=ModelType.FSDP,
            async_op=True,
            cache_staged_state_dict=cache_staged_state_dict,
        )

    def _run_e2e_test(
        self, compile, model_type, async_op=False, cache_staged_state_dict=False
    ):
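        """Trains a non-parallel reference model and a distributed model, saves
        the distributed one (synchronously or via async_save), reloads it into
        fresh wrappers, and checks that losses and state dicts match."""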
        model, optim = self._create_model(compile, ModelType.NONE)
        _train(model, optim, train_steps=2)

        dist_model, dist_optim = self._create_model(compile, model_type)
        _, original_train_state = _train(dist_model, dist_optim, train_steps=2)

        original_stateful_obj = TestStatefulObj()  # tests arbitrary saving/loading
        sd = {
            "model": dist_model,
            "optimizer": dist_optim,
            "s": original_stateful_obj,
            "train_state": original_train_state,
        }

        if async_op:
            writer = DCP.FileSystemWriter(
                self.temp_dir, cache_staged_state_dict=cache_staged_state_dict
            )
            f = saver.async_save(sd, storage_writer=writer)
            t = time.monotonic()
            while not f.done():
                time.sleep(1)
                print(f"still waiting... {time.monotonic() - t}")

            f.result()
        else:
            DCP.save(sd, checkpoint_id=self.temp_dir)

        loaded_stateful_obj = TestStatefulObj()
        loaded_train_state = TestTrainState()
        dist_model, dist_optim = self._create_model(compile, model_type)

        DCP.load(
            state_dict={
                "model": dist_model,
                "optimizer": dist_optim,
                "s": loaded_stateful_obj,
                "train_state": loaded_train_state,
            },
            checkpoint_id=self.temp_dir,
        )

        self.assertEqual(original_stateful_obj, loaded_stateful_obj)
        self.assertEqual(original_train_state, loaded_train_state)

        # train one more step on both models
        loss, _ = _train(model, optim, train_steps=1)
        dist_loss, _ = _train(dist_model, dist_optim, train_steps=1)
        self.assertEqual(loss, dist_loss)

        dist_msd, dist_osd = get_state_dict(dist_model, optimizers=dist_optim)
        model_sd, optim_sd = get_state_dict(model, optimizers=optim)

        self._verify_msd(model_sd, dist_msd)
        self._verify_osd_by_load(model, optim, self._optim(model), dist_osd)

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(4)
    def test_different_ordered_state_dict_keys(self):
        """Tests that the order of keys in the state dict does not matter when loading.

        If the order were not accounted for, this test would deadlock.
        """

        world_size = self.world_size

        class Foo:
            def state_dict(self):
                return {}

            def load_state_dict(self, state_dict):
                tl = [
                    torch.ones(2, dtype=torch.int64, device="cuda")
                    for _ in range(world_size)
                ]
                t = (
                    torch.arange(2, dtype=torch.int64, device="cuda")
                    + 1
                    + 2 * dist.get_rank()
                )
                dist.all_gather(tl, t, async_op=False)

        class Bar:
            def state_dict(self):
                return {}

            def load_state_dict(self, state_dict):
                tensor = (
                    torch.arange(2, dtype=torch.int64, device="cuda")
                    + 1
                    + 2 * dist.get_rank()
                )
                dist.all_reduce(tensor, op=ReduceOp.SUM)

        if self.rank == 0:
            sd = {
                "A": Foo(),
                "B": Bar(),
            }
        else:
            sd = {
                "B": Bar(),
                "A": Foo(),
            }

        DCP.save(sd, checkpoint_id=self.temp_dir)
        DCP.load(sd, checkpoint_id=self.temp_dir)

    @with_temp_dir
    def test_no_dist(self):
        # Since comms are not initialized in this method, `no_dist` is assumed
        # to be True and DCP falls back to single-process save/load.
        DCP.save({}, checkpoint_id=self.temp_dir)
        DCP.load({}, checkpoint_id=self.temp_dir)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_partial_load(self):
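        """Loads only a subset of the saved keys: first just the model via
        DCP.load, then raw model and optimizer state via _load_state_dict_from_keys."""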
        model, optim = self._create_model(compile=False, model_type=ModelType.NONE)
        _train(model, optim, train_steps=2)

        dist_model, dist_optim = self._create_model(
            compile=False, model_type=ModelType.FSDP
        )
        _train(dist_model, dist_optim, train_steps=2)

        DCP.save(
            {"model": dist_model, "optimizer": dist_optim}, checkpoint_id=self.temp_dir
        )

        dist_model, _ = self._create_model(compile=False, model_type=ModelType.FSDP)
        DCP.load({"model": dist_model}, checkpoint_id=self.temp_dir)

        dist_msd = get_model_state_dict(dist_model)
        model_sd = get_model_state_dict(model)
        self._verify_msd(model_sd, dist_msd)

        # Another way: read the model state directly from the checkpoint by key,
        # without materializing it into a module first.
        loaded_model_sd = _load_state_dict_from_keys(
            "model", checkpoint_id=self.temp_dir
        )["model"]
        self._verify_msd(model_sd, loaded_model_sd, offload_to_cpu=True)

        loaded_optim_state = _load_state_dict_from_keys(
            "optimizer.state", checkpoint_id=self.temp_dir
        )["optimizer"]["state"]
        self.assertNotIn("param_groups", loaded_optim_state)
        for k, v in dist_optim.state_dict()["state"].items():
            for optim_key in ["exp_avg", "exp_avg_sq", "step"]:
                self._compare_tensor(
                    loaded_optim_state[k][optim_key], v[optim_key], offload_to_cpu=True
                )

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_overwrite(self):
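        """Saving twice to the same checkpoint_id only succeeds with
        FileSystemWriter(overwrite=True); overwrite=False raises CheckpointException."""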
        t1, t2 = torch.randn(10), torch.randn(10)
        DCP.save({"random": t1}, checkpoint_id=self.temp_dir)
        DCP.save(
            {"random": t2},
            storage_writer=DCP.FileSystemWriter(self.temp_dir, overwrite=True),
        )

        sd = {"random": torch.zeros(10)}
        DCP.load(sd, checkpoint_id=self.temp_dir)

        self.assertTrue(torch.allclose(sd["random"], t2))

        with self.assertRaisesRegex(
            CheckpointException, ".*Checkpoint already exists.*"
        ):
            DCP.save(
                {"random": t2},
                storage_writer=DCP.FileSystemWriter(self.temp_dir, overwrite=False),
            )

class TestNoCPU(DTensorTestBase):
    @property
    def backend(self):
        return "nccl"

    @with_comms
    def test_no_cpu(self):
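        """async_save stages the state dict on CPU, so it is expected to assert
        when no CPU backend (e.g. gloo) is part of the process group."""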
        with self.assertRaisesRegex(
            AssertionError, r"A CPU backend must be enabled for async save;.*?"
        ):
            f = saver.async_save({})
            f.result()


class TestInitStateDict(DTensorTestBase):
    @with_temp_dir
    def test_init_state_dict(self):
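        """DCP.load should populate the provided state dict in place, and
        set_state_dict should then restore non-tensor values such as the
        optimizer learning rate."""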
        temp_dir = self.temp_dir
        model = TestDummyModel()
        optim = torch.optim.Adam(model.parameters(), lr=0.1)

        state_dict_to_save = {
            "model": get_model_state_dict(model),
            "optimizer": get_optimizer_state_dict(model, optim),
        }
        DCP.save(state_dict_to_save, checkpoint_id=temp_dir)

        torch.manual_seed(0)
        model_2 = TestDummyModel()
        # Change the optimizer learning rate, which is a non-tensor value.
        optim_2 = torch.optim.Adam(model_2.parameters(), lr=0.2)

        msd = get_model_state_dict(model_2)
        osd = get_optimizer_state_dict(model_2, optim_2)

        state_dict_to_load = {"model": msd, "optimizer": osd}
        DCP.load(state_dict_to_load, checkpoint_id=temp_dir)

        # Check that the two variables still point to the same objects in memory,
        # since DCP loading is claimed to be in-place.
        self.assertTrue(msd is state_dict_to_load["model"])
        self.assertTrue(osd is state_dict_to_load["optimizer"])

        # set_state_dict calls load_state_dict for both the model and the optimizer,
        # so optim_2.param_groups' learning rate should now be 0.1 instead of 0.2.
        set_state_dict(
            model_2,
            optim_2,
            model_state_dict=state_dict_to_load["model"],
            optim_state_dict=state_dict_to_load["optimizer"],
        )
        self.assertEqual(msd, get_model_state_dict(model_2))
        self.assertEqual(osd, get_optimizer_state_dict(model_2, optim_2))
        self.assertEqual(optim_2.param_groups[0]["lr"], 0.1)


instantiate_parametrized_tests(TestE2ESaveAndLoad)
if __name__ == "__main__":
    run_tests()
466