xref: /aosp_15_r20/external/pytorch/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Copyright (c) Meta Platforms, Inc. and affiliates
3
4"""
The following example demonstrates how to use PyTorch Distributed Checkpoint to save an FSDP model.
6
7This is the current recommended way to checkpoint FSDP.
torch.save() and torch.load() are not recommended when checkpointing sharded models.
9"""
10
11import os
12import shutil
13
14import torch
15import torch.distributed as dist
16import torch.distributed.checkpoint as dist_cp
17import torch.multiprocessing as mp
18from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
19from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
20from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
21
22
# Per-user checkpoint location under /scratch. The user name is taken from
# LOGNAME, falling back to USER and finally to a fixed placeholder, so that
# importing this module does not raise KeyError when LOGNAME is unset (for
# example in some containers or non-login shells). When LOGNAME is set to a
# non-empty value the resulting path is the same as before.
_CHECKPOINT_USER = os.environ.get("LOGNAME") or os.environ.get("USER") or "unknown"
CHECKPOINT_DIR = f"/scratch/{_CHECKPOINT_USER}/checkpoint"
24
25
def opt_at(opt, idx):
    """Return the entry of ``opt.state`` found at position ``idx``.

    The values of ``opt.state`` are materialized into a list, in the
    mapping's iteration order, and then indexed with ``idx``.
    """
    per_param_states = list(opt.state.values())
    return per_param_states[idx]
28
29
def init_model():
    """Build an FSDP-wrapped ``Linear(4, 4)`` and take one Adam step on it.

    The linear layer is moved to the CUDA device indexed by this process's
    rank before being wrapped. A single forward/backward pass on a random
    4x4 input is followed by ``optim.step()``.

    Returns:
        A ``(model, optim)`` tuple.
    """
    linear = torch.nn.Linear(4, 4)
    model = FSDP(linear.cuda(dist.get_rank()))
    optim = torch.optim.Adam(model.parameters(), lr=0.1)

    # One training step, so the model and optimizer have been updated once
    # before they are checkpointed or printed.
    loss = model(torch.rand(4, 4)).sum()
    loss.backward()
    optim.step()

    return model, optim
37
38
def print_params(stage, model_1, model_2, optim_1, optim_2):
    """Print the parameters and first optimizer state entry of both models.

    Args:
        stage: Label placed at the start of each printed report.
        model_1: First FSDP-wrapped model.
        model_2: Second FSDP-wrapped model.
        optim_1: Optimizer belonging to ``model_1``.
        optim_2: Optimizer belonging to ``model_2``.
    """
    # The weights and biases are formatted while both models have their
    # full parameters summoned.
    with FSDP.summon_full_params(model_1), FSDP.summon_full_params(model_2):
        header = f"{stage} --- rank: {dist.get_rank()}\n"
        param_report = (
            f"model.weight: {model_1.weight}\n"
            f"model_2.weight:{model_2.weight}\n"
            f"model.bias: {model_1.bias}\n"
            f"model_2.bias: {model_2.bias}\n"
        )
        print(header + param_report)

    # Only the first optimizer state entry of each optimizer is reported.
    first_state_1 = opt_at(optim_1, 0)
    first_state_2 = opt_at(optim_2, 0)
    optim_report = (
        f"optim exp_avg:{first_state_1['exp_avg']}\n"
        f"optim_2 exp_avg:{first_state_2['exp_avg']}\n"
        f"optim exp_avg_sq:{first_state_1['exp_avg_sq']}\n"
        f"optim_2 exp_avg_sq:{first_state_2['exp_avg_sq']}\n"
    )
    print(header + optim_report)
57
58
def run_fsdp_checkpoint_example(rank, world_size):
    """Save one FSDP model/optimizer pair and load it into a second pair.

    Initializes the default process group, checkpoints ``model_1`` and
    ``optim_1`` to ``CHECKPOINT_DIR`` using sharded state dicts, loads that
    checkpoint into a freshly initialized ``model_2`` and ``optim_2``, and
    prints both pairs before and after loading.

    Args:
        rank: Rank of this process; also used as its CUDA device index.
        world_size: Number of processes in the process group.
    """
    # Rendezvous settings for the world process group.
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "12355"})

    # Bring up the process group and bind this process to its GPU.
    dist.init_process_group("cpu:gloo,cuda:nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # First model: the one that gets checkpointed.
    model_1, optim_1 = init_model()

    # Write the sharded model and optimizer state to CHECKPOINT_DIR.
    # NOTE(review): recent PyTorch releases also expose dist_cp.save /
    # dist_cp.load; confirm whether save_state_dict / load_state_dict are
    # deprecated in the targeted version.
    with FSDP.state_dict_type(model_1, StateDictType.SHARDED_STATE_DICT):
        model_shards = model_1.state_dict()
        optim_shards = FSDP.optim_state_dict(model_1, optim_1)
        writer = dist_cp.FileSystemWriter(CHECKPOINT_DIR)
        dist_cp.save_state_dict(
            state_dict={"model": model_shards, "optim": optim_shards},
            storage_writer=writer,
        )

    # Second model: the load target.
    model_2, optim_2 = init_model()

    # Before loading, the two models are expected to print different values.
    print_params("Before loading", model_1, model_2, optim_1, optim_2)

    # Restore model_2 and optim_2 from CHECKPOINT_DIR.
    with FSDP.state_dict_type(model_2, StateDictType.SHARDED_STATE_DICT):
        # Only the model entry is requested here: the optimizer state_dict
        # cannot be loaded together with the model state_dict, so it is
        # read in a separate step below.
        to_load = {"model": model_2.state_dict()}
        dist_cp.load_state_dict(
            state_dict=to_load,
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
        )
        model_2.load_state_dict(to_load["model"])

        sharded_optim = load_sharded_optimizer_state_dict(
            model_state_dict=to_load["model"],
            optimizer_key="optim",
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
        )
        optim_2.load_state_dict(
            FSDP.optim_state_dict_to_load(model_2, optim_2, sharded_optim["optim"])
        )

    # After loading, the two models are expected to print the same values.
    print_params("After loading", model_1, model_2, optim_1, optim_2)

    # Tear down the world process group.
    dist.destroy_process_group()
120
121
if __name__ == "__main__":
    # One worker process is launched per CUDA device reported by torch.
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")

    # Remove any checkpoint left by a previous run; errors such as a missing
    # directory are ignored.
    shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True)

    # Only world_size is supplied via args; the worker's leading rank
    # argument is provided by mp.spawn.
    mp.spawn(run_fsdp_checkpoint_example, args=(world_size,), nprocs=world_size, join=True)
132