# Owner(s): ["oncall: distributed"]

import copy
import sys
from typing import Dict

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import checkpoint, fully_shard, replicate
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.fsdp.api import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_dist_composable import (
    CompositeModel,
    CompositeParamModel,
    UnitModule,
)
from torch.testing._internal.common_distributed import (
    SaveForwardInputsModel,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestFSDPCheckpoint(FSDPTest):
    @property
    def world_size(self) -> int:
        return 2

    # TODO: Define `use_same_inputs_across_ranks` for now for BC since some
    # test model configs do not have a simple base model to compare against. In
    # those cases, we use the same inputs across ranks so that the averaged
    # gradient equals the local gradient to check for parity. This means that
    # the gradient reduction is unchecked.
    def _test_parity(
        self,
        base_model: nn.Module,
        test_model: nn.Module,
        inp_size: torch.Size,
        inp_device: torch.device,
        grad_to_none: bool,
        use_same_inputs_across_ranks: bool,
    ):
        LR = 0.01
        base_optim = torch.optim.Adam(base_model.parameters(), lr=LR)
        test_optim = torch.optim.Adam(test_model.parameters(), lr=LR)

        for _ in range(5):
            if use_same_inputs_across_ranks:
                torch.manual_seed(0)
            x = torch.randn(inp_size, device=inp_device)
            test_loss = test_model(x).sum()
            base_loss = base_model(x).sum()

            self.assertEqual(test_loss, base_loss)

            test_loss.backward()
            test_optim.step()
            test_optim.zero_grad(set_to_none=grad_to_none)

            base_loss.backward()
            base_optim.step()
            base_optim.zero_grad(set_to_none=grad_to_none)

    @skip_if_lt_x_gpu(2)
    def test_wrap_same_submodule(self):
        model = UnitModule(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)

        test_model = copy.deepcopy(model)
        # compose checkpoint and fully_shard
        test_model.seq = checkpoint(test_model.seq)
        test_model.seq = fully_shard(
            test_model.seq,
            policy=ModuleWrapPolicy({nn.Linear}),
        )

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    def _test_checkpoint_fsdp_submodules(self):
        model = CompositeModel(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)

        test_model = copy.deepcopy(model)
        test_model.u1 = fully_shard(test_model.u1, policy=None)
        test_model.u2 = fully_shard(test_model.u2)

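        # Apply activation checkpointing to the inner `seq` submodules after
        # their parents have been wrapped with `fully_shard`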
        test_model.u1.seq = checkpoint(test_model.u1.seq)
        test_model.u2.seq = checkpoint(test_model.u2.seq)

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fsdp_submodules_non_reentrant(self):
        self._test_checkpoint_fsdp_submodules()

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fully_shard_cast_forward_inputs(self):
        self.run_subtests(
            {
                "checkpoint_strict_submodule": [False, True],
            },
            self._test_checkpoint_fully_shard_cast_forward_inputs,
        )

    def _test_checkpoint_fully_shard_cast_forward_inputs(
        self, checkpoint_strict_submodule: bool
    ):
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        fp16_mp = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)
        fp32_mp = MixedPrecision(param_dtype=torch.float32, cast_forward_inputs=True)

        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=False
        ).cuda()
        x = torch.zeros(2, 100, device="cuda")

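        # Run `c2` in fp16 with its forward inputs cast and the root in fp32;
        # checkpointing `c2` (or only its strict submodule `c2.l`) should not
        # change the dtypes recorded in `forward_inputs`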
        fully_shard(model.c2, mixed_precision=fp16_mp)
        if checkpoint_strict_submodule:
            checkpoint(model.c2.l)
        else:
            checkpoint(model.c2)
        fully_shard(model, mixed_precision=fp32_mp)

        loss = model(x).sum()
        loss.backward()

        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
        # Notably, check that the recomputed forward preserves the right dtype
        self.assertEqual(forward_inputs[model.c2].dtype, torch.float16)

    @skip_if_lt_x_gpu(2)
    def test_fully_shard_replicate_correct_replicate_params(self):
        model = CompositeParamModel(device=torch.device("cuda"))
        # Shard the `nn.Linear`s within each `UnitModule`
        fully_shard(model.u1, policy=ModuleWrapPolicy({nn.Linear}))
        fully_shard(model.u2, policy=ModuleWrapPolicy({nn.Linear}))
        # Replicate the rest
        replicate(model)
        # Run forward + backward to initialize DDP
        inp = torch.randn(2, 100, device="cuda")
        model(inp).sum().backward()
        # Ensure that the replicated parameter names are as expected: the
        # immediate parameters of `model` and the parameters of `model`'s
        # non-`UnitModule` submodules should be replicated
        param_names = replicate.state(model)._param_names
        replicated_modules = [
            (name, mod)
            for (name, mod) in model.named_children()
            if mod not in [model.u1, model.u2]
        ]
        replicated_param_names = [
            f"{module_name}.{n}"
            for module_name, mod in replicated_modules
            for n, _ in mod.named_parameters()
        ]
        replicated_param_names.extend(
            [n for n, _ in model.named_parameters(recurse=False)]
        )
        self.assertEqual(set(param_names), set(replicated_param_names))

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fsdp_submodules_with_param(self):
        model = CompositeParamModel(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)

        test_model = copy.deepcopy(model)
        test_model.u1.seq = checkpoint(test_model.u1.seq)
        test_model.u2.seq = checkpoint(test_model.u2.seq)
        test_model = fully_shard(test_model)

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fsdp_submodules_with_param_no_shard(self):
        model = CompositeParamModel(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)

        test_model = copy.deepcopy(model)
        test_model.u1.seq = checkpoint(test_model.u1.seq)
        test_model.u2.seq = checkpoint(test_model.u2.seq)
        test_model = fully_shard(test_model, strategy=ShardingStrategy.NO_SHARD)

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    @skip_if_lt_x_gpu(2)
    def test_composable_fsdp_replicate(self):
        # Verify how the APIs can be composed: e.g., applying both
        # `fully_shard` and `replicate` to the same module should raise an
        # exception
        model = CompositeModel(device=torch.device("cuda"))
        fully_shard(model.l1)
        with self.assertRaisesRegex(RuntimeError, "Cannot apply .*replicate"):
            replicate(model.l1)
        replicate(model.l2)  # should not raise

    @skip_if_lt_x_gpu(2)
    def test_fully_shard_replicate_composability(self):
        """
        Tests composing ``fully_shard`` and ``replicate``. To save unit test
        time, we run the different configs in subtests.
        """
        self.run_subtests(
            {
                "config": [
                    "1fm,1r",
                    "1r,1fm",
                    "1r,1fa",
                    "1r1fm,1fm",
                    "1r1fa,1fm",
                    "1fm1fm,1r1r,1fm",
                ]
            },
            self._test_replicate_in_fully_shard,
        )

    def _test_replicate_in_fully_shard(self, config: str):
        """
        To interpret the config, each comma delineates a level in the module
        tree ordered bottom-up; 'r' means ``replicate``; 'f' means
        ``fully_shard``; 'a' means auto wrap; and 'm' means manual wrap.
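        For example, "1r1fm,1fm" applies ``replicate`` to ``u1`` and manual
        ``fully_shard`` to ``u2`` at the lower level and then manual
        ``fully_shard`` at the root.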
        """
        # Set the seed to ensure that all ranks initialize the same model
        torch.manual_seed(0)
        if config == "1fm,1r":
            base_model = CompositeModel(device=torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            fully_shard(test_model.l1)
            replicate(test_model)
        elif config == "1r,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model)
        elif config == "1r,1fa":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model, policy=ModuleWrapPolicy({UnitModule}))
        elif config == "1r1fm,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model.u2)
            fully_shard(test_model)
        elif config == "1r1fa,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model.u2, policy=ModuleWrapPolicy({UnitModule}))
            fully_shard(test_model)
        elif config == "1fm1fm,1r1r,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            fully_shard(test_model.u1.seq)
            fully_shard(test_model.u2.seq)
            replicate(test_model.u1)
            replicate(test_model.u2)
            fully_shard(test_model)
        else:
            raise ValueError(f"Unknown config: {config}")
        # Apply data parallelism to the base model for parity since we apply
        # data parallelism to the test model
        replicate(base_model)

        # Set the seed to ensure that ranks get different input data
        torch.manual_seed(self.rank + 1)
        self._test_parity(
            base_model,
            test_model,
            torch.Size((2, 100)),
            torch.device("cuda"),
            True,
            False,
        )

    @skip_if_lt_x_gpu(2)
    def test_state_dict_fsdp_submodules(self):
        model = CompositeModel(device=torch.device("cuda"))

        full_shard_args = {"strategy": ShardingStrategy.FULL_SHARD}
        no_shard_args = {"strategy": ShardingStrategy.NO_SHARD}

        model.u1 = fully_shard(model.u1, **full_shard_args)
        model.u2 = fully_shard(model.u2, **no_shard_args)

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
        )

        state_dict = model.state_dict()
        for fqn, tensor in state_dict.items():
            if "u1" in fqn:
                self.assertIsInstance(tensor, ShardedTensor)
            elif "u2" in fqn:
                self.assertIsInstance(tensor, torch.Tensor)
        # Ensure that get_state_dict_type can still correctly get the settings.
        _ = FSDP.get_state_dict_type(model)


instantiate_parametrized_tests(TestFSDPCheckpoint)


if __name__ == "__main__":
    run_tests()