# Owner(s): ["oncall: distributed"]

import copy
import functools
from typing import Dict, List, Optional, Union

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._composable.fsdp._fsdp_collectives import (
    _get_gradient_divide_factors,
)
from torch.testing._internal.common_distributed import (
    requires_nccl_version,
    SaveForwardInputsModel,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import (
    check_sharded_parity,
    FSDPTest,
    FSDPTestMultiThread,
    MLP,
    patch_reduce_scatter,
    reduce_scatter_with_assert,
)
from torch.testing._internal.common_utils import run_tests


class TestFullyShardMixedPrecisionTraining(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    def _init_models_and_optims(
        self,
        reshard_after_forward: Union[bool, int],
        param_dtype: Optional[torch.dtype],
        reduce_dtype: Optional[torch.dtype],
    ):
        torch.manual_seed(42)
        model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
        ref_model = copy.deepcopy(model).cuda()
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        mp_policy = MixedPrecisionPolicy(
            param_dtype=param_dtype, reduce_dtype=reduce_dtype
        )
        fully_shard_fn = functools.partial(
            fully_shard,
            reshard_after_forward=reshard_after_forward,
            mp_policy=mp_policy,
        )
        for mlp in model:
            fully_shard_fn(mlp)
        fully_shard_fn(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
        return ref_model, ref_optim, model, optim

    @skip_if_lt_x_gpu(2)
    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
    def test_compute_dtype(self):
        self.run_subtests(
            {
                "param_dtype": [torch.bfloat16, torch.float16],
                "reshard_after_forward": [False, True, 2],
            },
            self._test_compute_dtype,
        )

    def _test_compute_dtype(
        self, param_dtype: torch.dtype, reshard_after_forward: Union[bool, int]
    ):
        ref_model, ref_optim, model, optim = self._init_models_and_optims(
            reshard_after_forward, param_dtype=param_dtype, reduce_dtype=None
        )
        ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
        orig_reduce_scatter = dist.reduce_scatter_tensor

        def assert_fn(output: torch.Tensor):
            self.assertEqual(output.dtype, param_dtype)

        reduce_scatter = functools.partial(
            reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
        )
        predivide_factor, postdivide_factor = _get_gradient_divide_factors(
            self.process_group, all_reduce_group=None, reduce_dtype=param_dtype
        )

        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((4, 16), device="cuda", dtype=param_dtype)
        for iter_idx in range(10):
            optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            fsdp_loss = model(inp).sum()
            with patch_reduce_scatter(reduce_scatter):
                fsdp_loss.backward()
            optim.step()

            ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            ref_loss = ref_model_bf16(inp.to(param_dtype)).sum()
            ref_loss.backward()
            for param in ref_model_bf16.parameters():
                # Use reduce-scatter -> all-gather as all-reduce because for
                # world size >=4, NCCL all-reduce shows numeric differences
                # compared with NCCL reduce-scatter
                if predivide_factor is not None and predivide_factor > 1:
                    param.grad.div_(predivide_factor)
                elif predivide_factor is None:
                    param.grad.div_(self.world_size)
                output = torch.zeros_like(torch.chunk(param.grad, self.world_size)[0])
                dist.reduce_scatter_tensor(output, param.grad)
                dist.all_gather_into_tensor(param.grad, output)
                if postdivide_factor is not None and postdivide_factor > 1:
                    param.grad.div_(postdivide_factor)
            for param_fp32, param_bf16 in zip(
                ref_model.parameters(), ref_model_bf16.parameters()
            ):
                param_fp32.grad = param_bf16.grad.to(param_fp32.dtype)
                param_bf16.grad = None
            ref_optim.step()  # fp32 optimizer step
            for param_fp32, param_bf16 in zip(
                ref_model.parameters(), ref_model_bf16.parameters()
            ):
                param_bf16.detach().copy_(param_fp32)

            self.assertEqual(fsdp_loss, ref_loss)
            check_sharded_parity(self, ref_model, model)

    @skip_if_lt_x_gpu(2)
    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
    def test_reduce_dtype(self):
        self.run_subtests(
            {"reshard_after_forward": [False, True, 2]},
            self._test_reduce_dtype_fp32_reduce,
        )
        self.run_subtests(
            {"reshard_after_forward": [False, True, 2]},
            self._test_reduce_dtype_bf16_reduce,
        )

    def _test_reduce_dtype_fp32_reduce(self, reshard_after_forward: Union[bool, int]):
        param_dtype, reduce_dtype = torch.bfloat16, torch.float32
        ref_model, ref_optim, model, optim = self._init_models_and_optims(
            reshard_after_forward, param_dtype=param_dtype, reduce_dtype=reduce_dtype
        )
        ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
        orig_reduce_scatter = dist.reduce_scatter_tensor

        def assert_fn(output: torch.Tensor):
            self.assertEqual(output.dtype, reduce_dtype)

        reduce_scatter = functools.partial(
            reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
        )
        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((4, 16), device="cuda", dtype=param_dtype)
        for iter_idx in range(10):
            optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            fsdp_loss = model(inp).sum()
            with patch_reduce_scatter(reduce_scatter):
                fsdp_loss.backward()
            optim.step()

            ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            ref_loss = ref_model_bf16(inp.to(param_dtype)).sum()
            ref_loss.backward()
            for param in ref_model_bf16.parameters():
                param.grad.data = param.grad.to(torch.float32)
                dist.all_reduce(param.grad)  # fp32 reduction
                param.grad.div_(self.world_size)
            for param_fp32, param_bf16 in zip(
                ref_model.parameters(), ref_model_bf16.parameters()
            ):
                param_fp32.grad = param_bf16.grad
                param_bf16.grad = None
            ref_optim.step()  # fp32 optimizer step
            for param_fp32, param_bf16 in zip(
                ref_model.parameters(), ref_model_bf16.parameters()
            ):
                param_bf16.detach().copy_(param_fp32)

            self.assertEqual(fsdp_loss, ref_loss)
            check_sharded_parity(self, ref_model, model)

    def _test_reduce_dtype_bf16_reduce(self, reshard_after_forward: Union[bool, int]):
        param_dtype, reduce_dtype = torch.float32, torch.bfloat16
        ref_model, ref_optim, model, optim = self._init_models_and_optims(
            reshard_after_forward, param_dtype=param_dtype, reduce_dtype=reduce_dtype
        )
        group = dist.distributed_c10d._get_default_group()
        orig_reduce_scatter = dist.reduce_scatter_tensor

        def assert_fn(output: torch.Tensor):
            self.assertEqual(output.dtype, reduce_dtype)

        reduce_scatter = functools.partial(
            reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
        )
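        # Seed differently per rank so that each rank trains on distinct inputs;
        # the manual reduce-scatter/all-gather below then emulates data-parallel
        # gradient averaging for the replicated reference model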
        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((4, 16), device="cuda", dtype=param_dtype)
        for iter_idx in range(10):
            optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            fsdp_loss = model(inp).sum()
            with patch_reduce_scatter(reduce_scatter):
                fsdp_loss.backward()
            optim.step()

            ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            ref_loss = ref_model(inp).sum()
            ref_loss.backward()
            for param in ref_model.parameters():
                param_grad = param.grad.to(reduce_dtype)
                # Use reduce-scatter -> all-gather to implement all-reduce
                # since for world size >2, bf16 all-reduce and reduce-scatter
                # have numeric differences
                sharded_grad = funcol.reduce_scatter_tensor(
                    param_grad, scatter_dim=0, reduceOp="avg", group=group
                )  # bf16 reduction
                param.grad = funcol.all_gather_tensor(
                    sharded_grad, gather_dim=0, group=group
                ).to(param.dtype)  # upcast to fp32
            ref_optim.step()  # fp32 optimizer step

            self.assertEqual(fsdp_loss, ref_loss)
            check_sharded_parity(self, ref_model, model)

    @skip_if_lt_x_gpu(2)
    def test_grad_acc_with_reduce_dtype(self):
        """
        Tests that gradient accumulation without reduce-scatter when using
        bf16 compute and fp32 reduction accumulates the unsharded gradients in
        fp32.
        """
        self.run_subtests(
            {"reshard_after_forward": [True, False]},
            self._test_grad_acc_with_reduce_dtype,
        )

    def _test_grad_acc_with_reduce_dtype(self, reshard_after_forward: bool):
        torch.manual_seed(42)
        param_dtype, reduce_dtype = (torch.bfloat16, torch.float32)
        mp_policy = MixedPrecisionPolicy(
            param_dtype=param_dtype, reduce_dtype=reduce_dtype
        )
        model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
        # To emulate the mixed precision implementation where forward/backward
        # compute use bf16 and optimizer uses fp32, we maintain both an fp32
        # and a bf16 copy of the reference model
        ref_model = copy.deepcopy(model).cuda()
        ref_model_compute = copy.deepcopy(ref_model).to(param_dtype)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        for mlp in model:
            fully_shard(
                mlp, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
            )
        fully_shard(
            model, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
        )
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        orig_reduce_scatter = dist.reduce_scatter_tensor

        def assert_fn(output: torch.Tensor):
            self.assertEqual(output.dtype, reduce_dtype)

        reduce_scatter = functools.partial(
            reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
        )
        torch.manual_seed(42 + self.rank + 1)
        device = torch.device("cuda")
        # Train on the same input to avoid loss explosion
        num_microbatches = 4
        inp = torch.randn((2 * num_microbatches, 16), device=device, dtype=param_dtype)
        for iter_idx in range(10):
            microbatch_inps = torch.chunk(inp, 4)
            for microbatch_idx in range(num_microbatches):
                is_last_microbatch = microbatch_idx == num_microbatches - 1
                model.set_requires_gradient_sync(is_last_microbatch)
                model.set_reshard_after_backward(
                    is_last_microbatch or reshard_after_forward
                )
                losses: List[torch.Tensor] = []
                for _model in (ref_model_compute, model):
                    losses.append(
                        _model(microbatch_inps[microbatch_idx].detach()).sum()
                    )
                    self.assertEqual(losses[-1].dtype, param_dtype)
                    with patch_reduce_scatter(reduce_scatter):
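                        # Gradient sync is enabled only on the last microbatch,
                        # so FSDP issues reduce-scatter (asserted to run in fp32
                        # by the patched collective) only then; earlier
                        # microbatches accumulate unsharded gradients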
                        losses[-1].backward()
                self.assertEqual(losses[0], losses[1])
                # Manually accumulate gradients into the base reference model
                # from the compute reference model in fp32
                for ref_param, ref_param_compute in zip(
                    ref_model.parameters(), ref_model_compute.parameters()
                ):
                    self.assertTrue(ref_param_compute.grad is not None)
                    self.assertEqual(ref_param.dtype, torch.float32)
                    if ref_param.grad is not None:
                        ref_param.grad += ref_param_compute.grad
                    else:
                        ref_param.grad = ref_param_compute.grad.to(ref_param.dtype)
                    ref_param_compute.grad = None
                # Manually reduce gradients for the reference model on the last
                # microbatch to implement data parallelism
                if is_last_microbatch:
                    for ref_param in ref_model.parameters():
                        self.assertTrue(ref_param.grad is not None)
                        dist.all_reduce(ref_param.grad)
                        ref_param.grad /= self.world_size
            check_sharded_parity(self, ref_model, model)
            ref_optim.step()
            optim.step()
            ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            # Manually copy parameters from the base reference model to the
            # compute reference model to run the optimizer step for the latter
            for ref_param, ref_param_compute in zip(
                ref_model.parameters(), ref_model_compute.parameters()
            ):
                ref_param_compute.detach().copy_(ref_param)


class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(1)
    def test_float16_on_one_submodule(self):
        x = torch.zeros(2, 100, device="cuda")

        # Subtest 1: use fp16 on the second child submodule -- does not require
        # any additional casting logic
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        model = SaveForwardInputsModel(
            forward_inputs,
            cast_forward_inputs=False,
        ).cuda()
        fully_shard(model.c2, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
        fully_shard(model)
        model(x).sum().backward()
        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c2].dtype, torch.float16)

        # Subtest 2: use fp16 on the second child module, where the user module
        # owns the cast
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=True
        ).cuda()
        fully_shard(
            model.c2,
            mp_policy=MixedPrecisionPolicy(
                param_dtype=torch.float16, cast_forward_inputs=False
            ),
        )
        fully_shard(model)
        model(x).sum().backward()
        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c2].dtype, torch.float32)

        # Subtest 3: use fp16 on the first child module and specify its output
        # dtype so that the second child module does not need to cast
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=False
        ).cuda()
        fully_shard(
            model.c1,
            mp_policy=MixedPrecisionPolicy(
                param_dtype=torch.float16, output_dtype=torch.float32
            ),
        )
        fully_shard(model)
        model(x).sum().backward()
        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[model.c1].dtype, torch.float16)
        self.assertEqual(forward_inputs[model.c2].dtype, torch.float32)

    @skip_if_lt_x_gpu(1)
    def test_submodules_with_external_inputs(self):
        self.run_subtests(
            {"enable_submodule_cast": [False, True]},
            self._test_submodules_with_external_inputs,
        )

    def _test_submodules_with_external_inputs(self, enable_submodule_cast: bool):
        class ToyModule(nn.Module):
            def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None:
                super().__init__()
                self.l = nn.Linear(100, 100)
                self.forward_inputs = forward_inputs

            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                self.forward_inputs["l2_input_x"] = x
                self.forward_inputs["l2_input_y"] = y
                return self.l(x)

        class ToyModel(nn.Module):
            def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None:
                super().__init__()
                self.l1 = nn.Linear(100, 100)
                self.l2 = ToyModule(forward_inputs)
                self.forward_inputs = forward_inputs

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.forward_inputs["model_input_x"] = x
                y = torch.ones(
                    2, 100, device="cuda", dtype=torch.float32
                )  # external input
                return self.l2(self.l1(x), y)

        forward_inputs: Dict[str, torch.Tensor] = {}
        model = ToyModel(forward_inputs).cuda()
        x = torch.zeros(2, 100, device="cuda", dtype=torch.float32)
        fully_shard(
            model.l2,
            mp_policy=MixedPrecisionPolicy(
                param_dtype=torch.float16, cast_forward_inputs=enable_submodule_cast
            ),
        )
        fully_shard(model, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
        model(x).sum().backward()

        # If we enable `model.l2` to cast (the default), then `l2_input_y` gets
        # cast to fp16, and if we disable it, then it stays as fp32
        self.assertEqual(forward_inputs["model_input_x"].dtype, torch.float16)
        self.assertEqual(forward_inputs["l2_input_x"].dtype, torch.float16)
        self.assertEqual(
            forward_inputs["l2_input_y"].dtype,
            torch.float16 if enable_submodule_cast else torch.float32,
        )

    @skip_if_lt_x_gpu(1)
    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
    def test_norm_modules_bf16(self):
        mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
        self._test_norm_modules(mp_policy)

    @skip_if_lt_x_gpu(1)
    def test_norm_modules_fp16(self):
        mp_policy = MixedPrecisionPolicy(param_dtype=torch.float16)
        self._test_norm_modules(mp_policy)

    def _test_norm_modules(self, mp_policy: MixedPrecisionPolicy):
        def inner(model: nn.Module, x: torch.Tensor):
            # Run forward and backward to check for no type mismatch errors
            z = model(x)
            self.assertEqual(z.dtype, mp_policy.param_dtype)
            z.sum().backward()

        # Layer norm
        model = nn.Sequential(nn.Linear(32, 32), nn.LayerNorm(32), nn.Linear(32, 32))
        for module in (model[0], model[1], model[2], model):
            fully_shard(module, mp_policy=mp_policy)
        inner(model, torch.randn((4, 32)))

        # Batch norm 1D
        model = nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32), nn.Linear(32, 32))
        for module in (model[0], model[1], model[2], model):
            fully_shard(module, mp_policy=mp_policy)
        inner(model, torch.randn((4, 32)))

        # Batch norm 2D: error in backward from buffer dtype mismatch
        model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
        for module in (model[0], model[1], model[2], model):
            fully_shard(module, mp_policy=mp_policy)
        with self.assertRaisesRegex(RuntimeError, "Expected running_mean to have type"):
            # Errors in batch norm 2D backward
            inner(model, torch.randn((3, 1, 9, 9)))

        # Batch norm 2D: cast buffers down to lower precision
        model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
        for module in (model[0], model[1], model[2], model):
            fully_shard(module, mp_policy=mp_policy)
        # Casting batch norm buffers to the lower precision allows backward
        model[1].running_mean = model[1].running_mean.to(mp_policy.param_dtype)
        model[1].running_var = model[1].running_var.to(mp_policy.param_dtype)
        inner(model, torch.randn((3, 1, 9, 9)))

        # Batch norm 2D: use special mixed precision policy
        model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
        bn_mp_policy = MixedPrecisionPolicy(output_dtype=mp_policy.param_dtype)
        fully_shard(model[1], mp_policy=bn_mp_policy)
        for module in (model[0], model[2], model):
            fully_shard(module, mp_policy=mp_policy)
        inner(model, torch.randn((3, 1, 9, 9)))


if __name__ == "__main__":
    run_tests()