# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates

"""
The following example demonstrates how to use PyTorch Distributed Checkpoint to save an FSDP model.

This is the current recommended way to checkpoint FSDP.
torch.save() and torch.load() are not recommended when checkpointing sharded models.
"""

import os
import shutil

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch.multiprocessing as mp
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType


CHECKPOINT_DIR = f"/scratch/{os.environ['LOGNAME']}/checkpoint"


def opt_at(opt, idx):
    # Return the optimizer state entry for the parameter at the given index.
    return list(opt.state.values())[idx]


def init_model():
    # Wrap a single Linear layer in FSDP on this rank's GPU and run one
    # forward/backward/step so the Adam state (exp_avg, exp_avg_sq) is populated.
    model = FSDP(torch.nn.Linear(4, 4).cuda(dist.get_rank()))
    optim = torch.optim.Adam(model.parameters(), lr=0.1)
    model(torch.rand(4, 4)).sum().backward()
    optim.step()

    return model, optim


def print_params(stage, model_1, model_2, optim_1, optim_2):
    with FSDP.summon_full_params(model_1):
        with FSDP.summon_full_params(model_2):
            print(
                f"{stage} --- rank: {dist.get_rank()}\n"
                f"model_1.weight: {model_1.weight}\n"
                f"model_2.weight: {model_2.weight}\n"
                f"model_1.bias: {model_1.bias}\n"
                f"model_2.bias: {model_2.bias}\n"
            )

    print(
        f"{stage} --- rank: {dist.get_rank()}\n"
        f"optim_1 exp_avg: {opt_at(optim_1, 0)['exp_avg']}\n"
        f"optim_2 exp_avg: {opt_at(optim_2, 0)['exp_avg']}\n"
        f"optim_1 exp_avg_sq: {opt_at(optim_1, 0)['exp_avg_sq']}\n"
        f"optim_2 exp_avg_sq: {opt_at(optim_2, 0)['exp_avg_sq']}\n"
    )


def run_fsdp_checkpoint_example(rank, world_size):
    # Set up world pg
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # Initialize the process group
    dist.init_process_group("cpu:gloo,cuda:nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Create a model
    model_1, optim_1 = init_model()

    # Save the model to CHECKPOINT_DIR
    with FSDP.state_dict_type(model_1, StateDictType.SHARDED_STATE_DICT):
        state_dict = {
            "model": model_1.state_dict(),
            "optim": FSDP.optim_state_dict(model_1, optim_1),
        }

        dist_cp.save_state_dict(
            state_dict=state_dict,
            storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
        )

    # Create a second model
    model_2, optim_2 = init_model()

    # Print the model parameters for both models.
    # Before loading, the parameters should be different.
    print_params("Before loading", model_1, model_2, optim_1, optim_2)

    # Load model_2 with parameters saved in CHECKPOINT_DIR
    with FSDP.state_dict_type(model_2, StateDictType.SHARDED_STATE_DICT):
        state_dict = {
            "model": model_2.state_dict(),
            # cannot load the optimizer state_dict together with the model state_dict
        }

        dist_cp.load_state_dict(
            state_dict=state_dict,
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
        )
        model_2.load_state_dict(state_dict["model"])

        # The optimizer state is loaded separately: first read the sharded
        # optimizer state from the checkpoint using the model state_dict as
        # the reference for resharding.
        optim_state = load_sharded_optimizer_state_dict(
            model_state_dict=state_dict["model"],
            optimizer_key="optim",
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
        )

        # Then convert it into the flattened form expected by the optimizer
        # of the FSDP-wrapped model before loading it.
        flattened_osd = FSDP.optim_state_dict_to_load(
            model_2, optim_2, optim_state["optim"]
        )
        optim_2.load_state_dict(flattened_osd)

    # Print the model parameters for both models.
    # After loading, the parameters should be the same.
    print_params("After loading", model_1, model_2, optim_1, optim_2)

    # Shut down world pg
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True)
    mp.spawn(
        run_fsdp_checkpoint_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )