# Owner(s): ["oncall: distributed"]

import time
from dataclasses import dataclass, field
from enum import auto, Enum
from functools import partial
from io import BytesIO
from typing import Any, Dict, List

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as DCP
import torch.distributed.checkpoint.state_dict_saver as saver
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.checkpoint.state_dict import (
    _patch_model_state_dict,
    _patch_optimizer_state_dict,
    get_model_state_dict,
    get_optimizer_state_dict,
    get_state_dict,
    set_state_dict,
)
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys
from torch.distributed.checkpoint.utils import CheckpointException
from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
from torch.testing._internal.distributed.common_state_dict import VerifyStateDictMixin


# Simple and boring model
class TestDummyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Linear(8, 16)
        self.net2 = nn.Linear(16, 32)
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Linear(64, 8)

    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        x = F.relu(self.net4(x))
        return x

    def get_input(self):
        return torch.rand(8, 8, device="cuda")


class TestStatefulObj:
    def __init__(self) -> None:
        self.data = torch.rand(10, 10, device="cuda")

    def state_dict(self):
        return {"data": self.data}

    def load_state_dict(self, state_dict):
        self.data = state_dict["data"]

    def __eq__(self, other):
        return torch.equal(self.data, other.data)


class ModelType(Enum):
    FSDP = auto()
    HSDP = auto()
    FSDP_TP = auto()
    DDP = auto()
    NONE = auto()  # no parallelization


@dataclass
class TestTrainState:
    step: int = 0
    current_loss: float = -1
    losses: List[float] = field(default_factory=list)

    def state_dict(self) -> Dict[str, Any]:
        loss_bytes = BytesIO()
        torch.save(self.losses, loss_bytes)
        return {
            "step": torch.tensor(self.step, dtype=torch.int32),
            "current_loss": torch.tensor(self.current_loss, dtype=torch.float32),
            "losses": loss_bytes,
        }

    def load_state_dict(self, state_dict) -> None:
        self.step = state_dict["step"].item()
        self.current_loss = state_dict["current_loss"].item()
        state_dict["losses"].seek(0)
        self.losses = torch.load(state_dict["losses"])

    def __eq__(self, other):
        return (
            self.step == other.step
            and self.current_loss == other.current_loss
            and self.losses == other.losses
        )


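# Runs a few optimizer steps, recording a simulated per-step loss in
# TestTrainState, and returns the final loss together with the train state.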
def _train(model, optim, train_steps=1):
    torch.manual_seed(0)
    loss = None

    train_state = TestTrainState()

    for _ in range(train_steps):
        loss = model(model.get_input()).sum()
        loss.backward()

        # In real training we usually sync the loss across dp ranks.
        # Here we just simulate it for testing purposes.
        train_state.step += 1
        train_state.current_loss = torch.rand(1).item()
        train_state.losses.append(train_state.current_loss)

        optim.step()
        optim.zero_grad()

    return loss, train_state


class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
    @property
    def backend(self):
        return "cpu:gloo,cuda:nccl"

    def _create_model(self, compile, model_type, state_dict_options=None):
        dummy_model = TestDummyModel().cuda()

        assert model_type in ModelType, f"{model_type} is not supported."
        if model_type == ModelType.FSDP:
            device_mesh = init_device_mesh(self.device_type, (self.world_size,))
            model = FSDP(
                dummy_model,
                device_mesh=device_mesh,
                use_orig_params=True,
            )
        elif model_type == ModelType.HSDP:
            device_mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
            model = FSDP(
                dummy_model,
                device_mesh=device_mesh,
                use_orig_params=True,
                sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            )
        elif model_type == ModelType.FSDP_TP:
            mesh_2d = init_device_mesh(
                self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
            )
            tp_mesh = mesh_2d["tp"]
            dp_mesh = mesh_2d["dp"]
            parallelize_plan = {
                "net1": ColwiseParallel(),
                "net2": RowwiseParallel(),
            }
            model = parallelize_module(dummy_model, tp_mesh, parallelize_plan)
            model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)
        elif model_type == ModelType.DDP:
            model = DistributedDataParallel(dummy_model)
            model.get_input = partial(TestDummyModel.get_input, model)
        else:
            model = dummy_model

        if compile:
            # TODO: enable dynamic=True when dynamic shape support is enabled.
            # model = torch.compile(model)
            model = torch.compile(model, dynamic=False)

        optim = self._optim(model)
        if model_type is not ModelType.NONE:
            _patch_model_state_dict(model, options=state_dict_options)
            _patch_optimizer_state_dict(
                model, optimizers=optim, options=state_dict_options
            )

        return model, optim

    def _optim(self, model):
        return torch.optim.Adam(model.parameters(), lr=0.1)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    @parametrize("compile", [True, False])
    # TODO: Previously, PairwiseParallel did not shard properly, so the
    # ModelType.FSDP_TP test passed where it should have failed. Disabling the
    # failing test temporarily to unblock the deprecation of PairwiseParallel.
    @parametrize("model_type", [ModelType.FSDP, ModelType.HSDP, ModelType.DDP])
    def test_e2e(self, compile, model_type):
        self._run_e2e_test(compile, model_type)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    @parametrize("cache_staged_state_dict", [False, True])
    def test_e2e_async_cached(self, cache_staged_state_dict):
        self._run_e2e_test(
            compile=False,
            model_type=ModelType.FSDP,
            async_op=True,
            cache_staged_state_dict=cache_staged_state_dict,
        )

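    # Shared test driver: trains a single-process baseline and a parallelized
    # copy, saves the distributed state with DCP (optionally async), restores it
    # into freshly created instances, and verifies losses and state dicts match.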
    def _run_e2e_test(
        self, compile, model_type, async_op=False, cache_staged_state_dict=False
    ):
        model, optim = self._create_model(compile, ModelType.NONE)
        _train(model, optim, train_steps=2)

        dist_model, dist_optim = self._create_model(compile, model_type)
        _, original_train_state = _train(dist_model, dist_optim, train_steps=2)

        original_stateful_obj = TestStatefulObj()  # tests arbitrary saving/loading
        sd = {
            "model": dist_model,
            "optimizer": dist_optim,
            "s": original_stateful_obj,
            "train_state": original_train_state,
        }

        if async_op:
            writer = DCP.FileSystemWriter(
                self.temp_dir, cache_staged_state_dict=cache_staged_state_dict
            )
            f = saver.async_save(sd, storage_writer=writer)
            t = time.monotonic()
            while not f.done():
                time.sleep(1)
                print(f"still waiting... {time.monotonic() - t}")

            f.result()
        else:
            DCP.save(sd, checkpoint_id=self.temp_dir)

        loaded_stateful_obj = TestStatefulObj()
        loaded_train_state = TestTrainState()
        dist_model, dist_optim = self._create_model(compile, model_type)

        DCP.load(
            state_dict={
                "model": dist_model,
                "optimizer": dist_optim,
                "s": loaded_stateful_obj,
                "train_state": loaded_train_state,
            },
            checkpoint_id=self.temp_dir,
        )

        self.assertEqual(original_stateful_obj, loaded_stateful_obj)
        self.assertEqual(original_train_state, loaded_train_state)

        # train one more step on both models
        loss, _ = _train(model, optim, train_steps=1)
        dist_loss, _ = _train(dist_model, dist_optim, train_steps=1)
        self.assertEqual(loss, dist_loss)

        dist_msd, dist_osd = get_state_dict(dist_model, optimizers=dist_optim)
        model_sd, optim_sd = get_state_dict(model, optimizers=optim)

        self._verify_msd(model_sd, dist_msd)
        self._verify_osd_by_load(model, optim, self._optim(model), dist_osd)

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(4)
    def test_different_ordered_state_dict_keys(self):
        """Tests that the order of keys in the state dict does not matter when loading.

        If order were not accounted for, this test would deadlock.
        """

        world_size = self.world_size

        class Foo:
            def state_dict(self):
                return {}

            def load_state_dict(self, state_dict):
                tl = [
                    torch.ones(2, dtype=torch.int64, device="cuda")
                    for _ in range(world_size)
                ]
                t = (
                    torch.arange(2, dtype=torch.int64, device="cuda")
                    + 1
                    + 2 * dist.get_rank()
                )
                dist.all_gather(tl, t, async_op=False)

        class Bar:
            def state_dict(self):
                return {}

            def load_state_dict(self, state_dict):
                tensor = (
                    torch.arange(2, dtype=torch.int64, device="cuda")
                    + 1
                    + 2 * dist.get_rank()
                )
                dist.all_reduce(tensor, op=ReduceOp.SUM)

        if self.rank == 0:
            sd = {
                "A": Foo(),
                "B": Bar(),
            }
        else:
            sd = {
                "B": Bar(),
                "A": Foo(),
            }

        DCP.save(sd, checkpoint_id=self.temp_dir)
        DCP.load(sd, checkpoint_id=self.temp_dir)

    @with_temp_dir
    def test_no_dist(self):
        # Comms are not initialized in this test, so DCP falls back to
        # single-process (`no_dist`) saving and loading.
        DCP.save({}, checkpoint_id=self.temp_dir)
        DCP.load({}, checkpoint_id=self.temp_dir)

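    # Saves the full model + optimizer state, then loads back only parts of it:
    # first by passing a partial state dict to DCP.load, then via
    # _load_state_dict_from_keys, which reads just the requested keys from the
    # checkpoint.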
    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_partial_load(self):
        model, optim = self._create_model(compile=False, model_type=ModelType.NONE)
        _train(model, optim, train_steps=2)

        dist_model, dist_optim = self._create_model(
            compile=False, model_type=ModelType.FSDP
        )
        _train(dist_model, dist_optim, train_steps=2)

        DCP.save(
            {"model": dist_model, "optimizer": dist_optim}, checkpoint_id=self.temp_dir
        )

        dist_model, _ = self._create_model(compile=False, model_type=ModelType.FSDP)
        DCP.load({"model": dist_model}, checkpoint_id=self.temp_dir)

        dist_msd = get_model_state_dict(dist_model)
        model_sd = get_model_state_dict(model)
        self._verify_msd(model_sd, dist_msd)

        # another way
        loaded_model_sd = _load_state_dict_from_keys(
            "model", checkpoint_id=self.temp_dir
        )["model"]
        self._verify_msd(model_sd, loaded_model_sd, offload_to_cpu=True)

        loaded_optim_state = _load_state_dict_from_keys(
            "optimizer.state", checkpoint_id=self.temp_dir
        )["optimizer"]["state"]
        self.assertNotIn("param_groups", loaded_optim_state)
        for k, v in dist_optim.state_dict()["state"].items():
            for optim_key in ["exp_avg", "exp_avg_sq", "step"]:
                self._compare_tensor(
                    loaded_optim_state[k][optim_key], v[optim_key], offload_to_cpu=True
                )

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_overwrite(self):
        t1, t2 = torch.randn(10), torch.randn(10)
        DCP.save({"random": t1}, checkpoint_id=self.temp_dir)
        DCP.save(
            {"random": t2},
            storage_writer=DCP.FileSystemWriter(self.temp_dir, overwrite=True),
        )

        sd = {"random": torch.zeros(10)}
        DCP.load(sd, checkpoint_id=self.temp_dir)

        self.assertTrue(torch.allclose(sd["random"], t2))

        with self.assertRaisesRegex(
            CheckpointException, ".*Checkpoint already exists.*"
        ):
            DCP.save(
                {"random": t2},
                storage_writer=DCP.FileSystemWriter(self.temp_dir, overwrite=False),
            )


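# async_save requires a process group with a CPU backend (e.g. "cpu:gloo");
# with only NCCL available it is expected to fail with the assertion below.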
class TestNoCPU(DTensorTestBase):
    @property
    def backend(self):
        return "nccl"

    @with_comms
    def test_no_cpu(self):
        with self.assertRaisesRegex(
            AssertionError, r"A CPU backend must be enabled for async save;.*?"
        ):
            f = saver.async_save({})
            f.result()


class TestInitStateDict(DTensorTestBase):
    @with_temp_dir
    def test_init_state_dict(self):
        temp_dir = self.temp_dir
        model = TestDummyModel()
        optim = torch.optim.Adam(model.parameters(), lr=0.1)

        state_dict_to_save = {
            "model": get_model_state_dict(model),
            "optimizer": get_optimizer_state_dict(model, optim),
        }
        DCP.save(state_dict_to_save, checkpoint_id=temp_dir)

        torch.manual_seed(0)
        model_2 = TestDummyModel()
        # Change the optimizer's learning rate, which is not a tensor.
        optim_2 = torch.optim.Adam(model_2.parameters(), lr=0.2)

        msd = get_model_state_dict(model_2)
        osd = get_optimizer_state_dict(model_2, optim_2)

        state_dict_to_load = {"model": msd, "optimizer": osd}
        DCP.load(state_dict_to_load, checkpoint_id=temp_dir)

        # Check that the two variables point to the same objects in memory,
        # since DCP loading is in-place.
        self.assertTrue(msd is state_dict_to_load["model"])
        self.assertTrue(osd is state_dict_to_load["optimizer"])

        # set_state_dict calls load_state_dict for both the model and the optimizer,
        # so optim_2.param_groups should now have a learning rate of 0.1 instead of 0.2.
        set_state_dict(
            model_2,
            optim_2,
            model_state_dict=state_dict_to_load["model"],
            optim_state_dict=state_dict_to_load["optimizer"],
        )
        self.assertEqual(msd, get_model_state_dict(model_2))
        self.assertEqual(osd, get_optimizer_state_dict(model_2, optim_2))
        self.assertEqual(optim_2.param_groups[0]["lr"], 0.1)


instantiate_parametrized_tests(TestE2ESaveAndLoad)
if __name__ == "__main__":
    run_tests()