1# Owner(s): ["oncall: distributed"] 2 3from copy import deepcopy 4 5import torch 6import torch.nn as nn 7from torch.distributed._tensor import ( 8 DeviceMesh, 9 distribute_module, 10 distribute_tensor, 11 DTensor, 12 Replicate, 13 Shard, 14) 15from torch.testing._internal.common_utils import run_tests 16from torch.testing._internal.distributed._tensor.common_dtensor import ( 17 DTensorTestBase, 18 MLPModule, 19 with_comms, 20) 21 22 23# shard function to do full sharding on all parameters of a module 24def shard_fn(name, module, device_mesh): 25 if isinstance(module, nn.Linear): 26 for name, param in module.named_parameters(): 27 dist_param = torch.nn.Parameter( 28 distribute_tensor(param, device_mesh, [Shard(0)]) 29 ) 30 # make sure partial sum get cleared after backward() 31 dist_param.register_hook( 32 lambda grad: grad.redistribute(placements=[Shard(0)]) 33 ) 34 module.register_parameter(name, dist_param) 35 36 37# prepare input 38def input_fn(mod, inputs, device_mesh): 39 # split the input tensor to be sharded input 40 dist_inp = distribute_tensor(inputs[0], device_mesh, [Shard(0)]) 41 return dist_inp 42 43 44# prepare output to be local torch.Tensor 45def output_fn(mod, outputs, device_mesh): 46 assert isinstance(outputs, DTensor) 47 return outputs.redistribute(placements=[Replicate()] * device_mesh.ndim).to_local() 48 49 50class TestDTensorOptimizer(DTensorTestBase): 51 def _assert_optimizer( 52 self, 53 mesh, 54 model, 55 optim, 56 dist_model, 57 dist_optim, 58 inputs, 59 *, 60 rtol: float = 1.3e-6, 61 atol: float = 1e-5, 62 ): 63 for iter_idx in range(2): 64 # run forward/backward/optim for original model 65 optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) 66 out = model(inputs) 67 loss = out.sum() 68 loss.backward() 69 optim.step() 70 71 # run forward/backward/optim for distributed model 72 dist_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) 73 dist_out = dist_model(inputs) 74 dist_loss = dist_out.sum() 75 dist_loss.backward() 76 dist_optim.step() 77 78 # check that the optimizer update parameters with same numerics 79 for p1, p2 in zip(model.parameters(), dist_model.parameters()): 80 p2 = p2.full_tensor() 81 # Default 'rtol' and 'atol' for attr:`~torch.float32` are ``1.3e-6`` and ``1e-5`` 82 self.assertEqual(p1, p2, atol=atol, rtol=rtol) 83 84 def test_optimizer_foreach_supported_types_include_DTensor(self): 85 from torch.optim.optimizer import _foreach_supported_types 86 87 self.assertTrue(DTensor in _foreach_supported_types) 88 89 @with_comms 90 def test_adam_1d_sharding(self): 91 mesh = DeviceMesh(self.device_type, list(range(self.world_size))) 92 93 # lr as a Tensor is not supported for capturable=False and foreach=True 94 adam_float_lr_configs = [ 95 {"lr": 0.1, "foreach": False}, 96 {"lr": 0.1, "weight_decay": 0.05, "foreach": False}, 97 {"lr": 0.1, "weight_decay": 0.05}, 98 {"lr": 0.1, "weight_decay": 0.05, "amsgrad": True}, 99 { 100 "lr": 0.1, 101 "weight_decay": 0.05, 102 "maximize": True, 103 "amsgrad": True, 104 }, 105 ] 106 fused_adam_float_lr_configs = [ 107 {"lr": 0.1, "fused": True}, 108 {"lr": 0.1, "weight_decay": 0.05, "amsgrad": True, "fused": True}, 109 { 110 "lr": 0.1, 111 "weight_decay": 0.05, 112 "maximize": True, 113 "amsgrad": True, 114 "fused": True, 115 }, 116 ] 117 # lr could be a Tensor or a float when fused=True for adam optimizer 118 fused_adam_tensor_lr_configs = [ 119 {**config, "lr": torch.tensor(0.1)} 120 for config in fused_adam_float_lr_configs 121 ] 122 fused_adam_tensor_lr_configs.extend( 123 [ 124 {**config, "lr": torch.tensor([0.1])} 125 for config in fused_adam_float_lr_configs 126 ] 127 ) 128 adam_configs = [ 129 *adam_float_lr_configs, 130 *fused_adam_float_lr_configs, 131 *fused_adam_tensor_lr_configs, 132 ] 133 134 for config in adam_configs: 135 mod = MLPModule(self.device_type) 136 opt = torch.optim.Adam(mod.parameters(), **config) 137 138 dist_mod = distribute_module( 139 deepcopy(mod), mesh, shard_fn, input_fn, output_fn 140 ) 141 dist_opt = torch.optim.Adam(dist_mod.parameters(), **config) 142 143 # use ones to make sure the single machine model have the same input 144 # on different ranks 145 inp = torch.ones(8, 10, device=self.device_type) 146 self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) 147 148 @with_comms 149 def test_adamw_1d_sharding(self): 150 mesh = DeviceMesh(self.device_type, list(range(self.world_size))) 151 152 # lr as a Tensor is not supported for capturable=False and foreach=True 153 adamw_float_lr_configs = [ 154 {"lr": 0.1, "foreach": False}, 155 {"lr": 0.1, "weight_decay": 0.05, "foreach": False}, 156 {"lr": 0.1, "weight_decay": 0.05}, 157 { 158 "lr": 0.1, 159 "betas": (0.6, 0.66), 160 "eps": 1e-6, 161 "weight_decay": 0.05, 162 "amsgrad": True, 163 }, 164 { 165 "lr": 0.1, 166 "betas": (0.6, 0.66), 167 "eps": 1e-6, 168 "weight_decay": 0.05, 169 "maximize": True, 170 "amsgrad": True, 171 }, 172 ] 173 fused_adamw_float_lr_configs = [ 174 {"lr": 0.1, "weight_decay": 0.05, "fused": True}, 175 { 176 "lr": 0.1, 177 "betas": (0.6, 0.66), 178 "eps": 1e-6, 179 "weight_decay": 0.05, 180 "amsgrad": True, 181 "fused": True, 182 }, 183 { 184 "lr": 0.1, 185 "betas": (0.6, 0.66), 186 "eps": 1e-6, 187 "weight_decay": 0.05, 188 "maximize": True, 189 "amsgrad": True, 190 "fused": True, 191 }, 192 ] 193 # lr could be a Tensor or a float when fused=True for adamW optimizer 194 fused_adamw_tensor_lr_configs = [ 195 {**config, "lr": torch.tensor(0.1)} 196 for config in fused_adamw_float_lr_configs 197 ] 198 fused_adamw_tensor_lr_configs.extend( 199 [ 200 {**config, "lr": torch.tensor([0.1])} 201 for config in fused_adamw_float_lr_configs 202 ] 203 ) 204 adamw_configs = [ 205 *adamw_float_lr_configs, 206 *fused_adamw_float_lr_configs, 207 *fused_adamw_tensor_lr_configs, 208 ] 209 210 for config in adamw_configs: 211 mod = MLPModule(self.device_type) 212 opt = torch.optim.AdamW(mod.parameters(), **config) 213 214 dist_mod = distribute_module( 215 deepcopy(mod), mesh, shard_fn, input_fn, output_fn 216 ) 217 dist_opt = torch.optim.AdamW(dist_mod.parameters(), **config) 218 219 # use ones to make sure the single machine model have the same input 220 # on different ranks 221 inp = torch.ones(8, 10, device=self.device_type) 222 self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) 223 224 @with_comms 225 def test_sgd_1d_sharding(self): 226 mesh = DeviceMesh(self.device_type, list(range(self.world_size))) 227 228 sgd_configs = [ 229 {"lr": 0.1, "foreach": False}, 230 {"lr": 0.1, "momentum": 0.05, "foreach": False}, 231 {"lr": 0.1, "momentum": 0.05}, 232 {"lr": 0.1, "momentum": 0.06, "dampening": 0.07}, 233 { 234 "lr": 0.1, 235 "momentum": 0.08, 236 "weight_decay": 0.05, 237 "nesterov": True, 238 "maximize": True, 239 "foreach": False, 240 }, 241 { 242 "lr": 0.1, 243 "momentum": 0.08, 244 "weight_decay": 0.05, 245 "nesterov": True, 246 "maximize": True, 247 }, 248 ] 249 250 for config in sgd_configs: 251 mod = MLPModule(self.device_type) 252 opt = torch.optim.SGD(mod.parameters(), **config) 253 254 dist_mod = distribute_module( 255 deepcopy(mod), mesh, shard_fn, input_fn, output_fn 256 ) 257 dist_opt = torch.optim.SGD(dist_mod.parameters(), **config) 258 259 # use ones to make sure the single machine model have the same input 260 # on different ranks 261 inp = torch.ones(8, 10, device=self.device_type) 262 self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) 263 264 @with_comms 265 def test_adagrad_1d_sharding(self): 266 mesh = DeviceMesh(self.device_type, list(range(self.world_size))) 267 268 adagrad_configs = [ 269 {"lr": 0.1, "foreach": False}, 270 {"lr": 0.1, "lr_decay": 0.05, "foreach": False}, 271 {"lr": 0.1, "lr_decay": 0.02, "weight_decay": 0.05, "foreach": False}, 272 { 273 "lr": 0.1, 274 "lr_decay": 0.02, 275 "weight_decay": 0.05, 276 "initial_accumulator_value": 0.03, 277 "foreach": False, 278 }, 279 { 280 "lr": 0.1, 281 "lr_decay": 0.02, 282 "weight_decay": 0.05, 283 "initial_accumulator_value": 0.03, 284 "eps": 1e-6, 285 "foreach": False, 286 }, 287 { 288 "lr": 0.1, 289 "lr_decay": 0.02, 290 "weight_decay": 0.05, 291 "initial_accumulator_value": 0.03, 292 "eps": 1e-6, 293 "maximize": True, 294 "foreach": False, 295 }, 296 { 297 "lr": 0.1, 298 "lr_decay": 0.02, 299 "weight_decay": 0.05, 300 "initial_accumulator_value": 0.03, 301 "eps": 1e-6, 302 "maximize": True, 303 }, 304 ] 305 306 for config in adagrad_configs: 307 mod = MLPModule(self.device_type) 308 opt = torch.optim.Adagrad(mod.parameters(), **config) 309 310 dist_mod = distribute_module( 311 deepcopy(mod), mesh, shard_fn, input_fn, output_fn 312 ) 313 dist_opt = torch.optim.Adagrad(dist_mod.parameters(), **config) 314 315 # use ones to make sure the single machine model have the same input 316 # on different ranks 317 inp = torch.ones(8, 10, device=self.device_type) 318 self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) 319 320 @with_comms 321 def test_RMSprop_1d_sharding(self): 322 mesh = DeviceMesh(self.device_type, list(range(self.world_size))) 323 324 RMSprop_configs = [ 325 {"lr": 0.1, "foreach": False}, 326 {"lr": 0.1, "alpha": 0.85, "foreach": False}, 327 {"lr": 0.1, "alpha": 0.88, "eps": 1e-6, "foreach": False}, 328 { 329 "lr": 0.1, 330 "alpha": 0.88, 331 "eps": 1e-6, 332 "weight_decay": 0.05, 333 "foreach": False, 334 }, 335 { 336 "lr": 0.1, 337 "alpha": 0.88, 338 "eps": 1e-6, 339 "weight_decay": 0.05, 340 "momentum": 0.9, 341 "foreach": False, 342 }, 343 { 344 "lr": 0.1, 345 "alpha": 0.88, 346 "eps": 1e-6, 347 "weight_decay": 0.05, 348 "momentum": 0.9, 349 "centered": True, 350 "foreach": False, 351 }, 352 { 353 "lr": 0.1, 354 "alpha": 0.88, 355 "eps": 1e-6, 356 "weight_decay": 0.05, 357 "momentum": 0.9, 358 "centered": True, 359 "maximize": True, 360 "foreach": False, 361 }, 362 { 363 "lr": 0.1, 364 "alpha": 0.88, 365 "eps": 1e-6, 366 "weight_decay": 0.05, 367 "momentum": 0.9, 368 "centered": True, 369 "maximize": True, 370 }, 371 ] 372 373 for config in RMSprop_configs: 374 mod = MLPModule(self.device_type) 375 opt = torch.optim.RMSprop(mod.parameters(), **config) 376 377 dist_mod = distribute_module( 378 deepcopy(mod), mesh, shard_fn, input_fn, output_fn 379 ) 380 dist_opt = torch.optim.RMSprop(dist_mod.parameters(), **config) 381 382 # use ones to make sure the single machine model have the same input 383 # on different ranks 384 inp = torch.ones(8, 10, device=self.device_type) 385 self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) 386 387 @with_comms 388 def test_adadelta_1d_sharding(self): 389 mesh = DeviceMesh(self.device_type, list(range(self.world_size))) 390 391 adadelta_configs = [ 392 {"lr": 0.1, "foreach": False}, 393 {"lr": 0.1, "rho": 0.85, "foreach": False}, 394 {"lr": 0.1, "rho": 0.88, "eps": 1e-5, "foreach": False}, 395 { 396 "lr": 0.1, 397 "rho": 0.88, 398 "eps": 1e-6, 399 "weight_decay": 0.05, 400 "foreach": False, 401 }, 402 { 403 "lr": 0.1, 404 "rho": 0.88, 405 "eps": 1e-6, 406 "weight_decay": 0.05, 407 }, 408 { 409 "lr": 0.1, 410 "rho": 0.88, 411 "eps": 1e-6, 412 "weight_decay": 0.05, 413 "maximize": True, 414 }, 415 ] 416 417 for config in adadelta_configs: 418 mod = MLPModule(self.device_type) 419 opt = torch.optim.Adadelta(mod.parameters(), **config) 420 421 dist_mod = distribute_module( 422 deepcopy(mod), mesh, shard_fn, input_fn, output_fn 423 ) 424 dist_opt = torch.optim.Adadelta(dist_mod.parameters(), **config) 425 426 # use ones to make sure the single machine model have the same input 427 # on different ranks 428 inp = torch.ones(8, 10, device=self.device_type) 429 self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) 430 431 @with_comms 432 def test_nadam_1d_sharding(self): 433 mesh = DeviceMesh(self.device_type, list(range(self.world_size))) 434 435 nadam_configs = [ 436 {"lr": 0.1, "foreach": False}, 437 {"lr": 0.1, "weight_decay": 0.05, "foreach": False}, 438 {"lr": 0.1, "weight_decay": 0.05}, 439 { 440 "lr": 0.1, 441 "betas": (0.6, 0.66), 442 "eps": 1e-6, 443 "weight_decay": 0.05, 444 }, 445 { 446 "lr": 0.1, 447 "betas": (0.6, 0.66), 448 "eps": 1e-6, 449 "weight_decay": 0.05, 450 "decoupled_weight_decay": True, 451 }, 452 ] 453 454 for config in nadam_configs: 455 mod = MLPModule(self.device_type) 456 opt = torch.optim.NAdam(mod.parameters(), **config) 457 458 dist_mod = distribute_module( 459 deepcopy(mod), mesh, shard_fn, input_fn, output_fn 460 ) 461 dist_opt = torch.optim.NAdam(dist_mod.parameters(), **config) 462 463 # use ones to make sure the single machine model have the same input 464 # on different ranks 465 inp = torch.ones(8, 10, device=self.device_type) 466 self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) 467 468 @with_comms 469 def test_radam_1d_sharding(self): 470 mesh = DeviceMesh(self.device_type, list(range(self.world_size))) 471 472 radam_configs = [ 473 {"lr": 0.1, "foreach": False}, 474 {"lr": 0.1, "weight_decay": 0.05, "foreach": False}, 475 { 476 "lr": 0.1, 477 "weight_decay": 0.05, 478 }, 479 { 480 "lr": 0.1, 481 "betas": (0.6, 0.66), 482 "eps": 1e-6, 483 "weight_decay": 0.05, 484 }, 485 { 486 "lr": 0.1, 487 "betas": (0.6, 0.66), 488 "eps": 1e-6, 489 "weight_decay": 0.05, 490 "decoupled_weight_decay": True, 491 }, 492 ] 493 494 for config in radam_configs: 495 mod = MLPModule(self.device_type) 496 opt = torch.optim.RAdam(mod.parameters(), **config) 497 498 dist_mod = distribute_module( 499 deepcopy(mod), mesh, shard_fn, input_fn, output_fn 500 ) 501 dist_opt = torch.optim.RAdam(dist_mod.parameters(), **config) 502 503 # use ones to make sure the single machine model have the same input 504 # on different ranks 505 inp = torch.ones(8, 10, device=self.device_type) 506 self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) 507 508 @with_comms 509 def test_adamax_1d_sharding(self): 510 mesh = DeviceMesh(self.device_type, list(range(self.world_size))) 511 512 adamax_configs = [ 513 {"lr": 0.1, "foreach": False}, 514 {"lr": 0.1, "betas": (0.6, 0.66), "foreach": False}, 515 {"lr": 0.1, "betas": (0.6, 0.66), "eps": 1e-6, "foreach": False}, 516 { 517 "lr": 0.1, 518 "betas": (0.6, 0.66), 519 "eps": 1e-6, 520 "weight_decay": 0.05, 521 "foreach": False, 522 }, 523 { 524 "lr": 0.1, 525 "betas": (0.6, 0.66), 526 "eps": 1e-6, 527 "weight_decay": 0.05, 528 }, 529 { 530 "lr": 0.1, 531 "betas": (0.6, 0.66), 532 "eps": 1e-6, 533 "weight_decay": 0.05, 534 "maximize": True, 535 }, 536 ] 537 538 for config in adamax_configs: 539 mod = MLPModule(self.device_type) 540 opt = torch.optim.Adamax(mod.parameters(), **config) 541 542 dist_mod = distribute_module( 543 deepcopy(mod), mesh, shard_fn, input_fn, output_fn 544 ) 545 dist_opt = torch.optim.Adamax(dist_mod.parameters(), **config) 546 547 # use ones to make sure the single machine model have the same input 548 # on different ranks 549 inp = torch.ones(8, 10, device=self.device_type) 550 self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) 551 552 @with_comms 553 def test_asgd_1d_sharding(self): 554 mesh = DeviceMesh(self.device_type, list(range(self.world_size))) 555 556 asgd_configs = [ 557 {"lr": 0.1, "foreach": False}, 558 {"lr": 0.1, "lambd": 0.001, "foreach": False}, 559 {"lr": 0.1, "lambd": 0.001, "alpha": 0.85, "foreach": False}, 560 {"lr": 0.1, "lambd": 0.001, "alpha": 0.85, "t0": 1e5, "foreach": False}, 561 { 562 "lr": 0.1, 563 "lambd": 0.001, 564 "alpha": 0.85, 565 "t0": 1e5, 566 "weight_decay": 0.05, 567 "foreach": False, 568 }, 569 { 570 "lr": 0.1, 571 "lambd": 0.001, 572 "alpha": 0.85, 573 "t0": 1e5, 574 "weight_decay": 0.05, 575 "foreach": True, 576 }, 577 { 578 "lr": 0.1, 579 "lambd": 0.001, 580 "alpha": 0.85, 581 "t0": 1e5, 582 "weight_decay": 0.05, 583 "foreach": True, 584 "maximize": True, 585 }, 586 ] 587 588 for config in asgd_configs: 589 mod = MLPModule(self.device_type) 590 opt = torch.optim.ASGD(mod.parameters(), **config) 591 592 dist_mod = distribute_module( 593 deepcopy(mod), mesh, shard_fn, input_fn, output_fn 594 ) 595 dist_opt = torch.optim.ASGD(dist_mod.parameters(), **config) 596 597 # use ones to make sure the single machine model have the same input 598 # on different ranks 599 inp = torch.ones(8, 10, device=self.device_type) 600 601 # TODO: We want to keep a unit test for ASGD optimizer for the time being, but we need to look into why 602 # when using ASGD we need higher atol and rtol when comparing model parameters. 603 # Default 'rtol' and 'atol' for attr:`~torch.float32` are ``1.3e-6`` and ``1e-5`` 604 # Pointer here: https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L65 605 self._assert_optimizer( 606 mesh, mod, opt, dist_mod, dist_opt, inp, atol=1.3e-5, rtol=1e-4 607 ) 608 609 610if __name__ == "__main__": 611 run_tests() 612