# Owner(s): ["module: optimizer"]
import functools
import math
import tempfile
import unittest
from copy import deepcopy
from typing import Any, Dict, Tuple
from unittest.mock import patch

from optim.test_lrscheduler import TestLRScheduler  # noqa: F401
from optim.test_optim import TestDifferentiableOptimizer  # noqa: F401
from optim.test_swa_utils import TestSWAUtils  # noqa: F401

import torch
from torch.nn import Parameter
from torch.optim import Optimizer, SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.optimizer import (
    register_optimizer_step_post_hook,
    register_optimizer_step_pre_hook,
)
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    largeTensorTest,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypes,
    skipMPS,
    TEST_WITH_ROCM,
)
from torch.testing._internal.common_dtype import floating_types_and
from torch.testing._internal.common_optimizers import (
    _get_device_type,
    _get_optim_inputs_including_global_cliquey_kwargs,
    optim_db,
    OptimizerErrorEnum,
    optims,
    TensorTracker,
)
from torch.testing._internal.common_utils import (
    markDynamoStrictTest,
    parametrize,
    run_tests,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
)


FP16_REDUCED_PRECISION = {"atol": 1e-5, "rtol": 1e-4}


def rosenbrock(tensor):
    assert tensor.size() == torch.Size(
        [2]
    ), f"Requires tensor with 2 scalars but got {tensor.size()}"
    x, y = tensor
    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


def drosenbrock(tensor):
    assert tensor.size() == torch.Size(
        [2]
    ), f"Requires tensor with 2 scalars but got {tensor.size()}"
    x, y = tensor
    return torch.stack((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2)))
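

# A minimal, hedged sanity-check sketch (the helper name is ours and it is not
# wired into the test suite; call it manually if desired): the hand-written
# gradient in drosenbrock should agree with autograd applied to rosenbrock,
# since d/dx = -400 * x * (y - x**2) - 2 * (1 - x) and d/dy = 200 * (y - x**2).
def _drosenbrock_matches_autograd_sketch():
    point = torch.tensor([1.5, 1.5], dtype=torch.float64, requires_grad=True)
    # Autograd gradient of the scalar Rosenbrock loss w.r.t. the 2-element input
    (autograd_grad,) = torch.autograd.grad(rosenbrock(point), point)
    # Compare against the analytic gradient used by the sparse tests below
    torch.testing.assert_close(drosenbrock(point.detach()), autograd_grad)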


@markDynamoStrictTest
class TestOptimRenewed(TestCase):
    """
    This test class validates the core optimizers and is structured around the correctness of:
    - The update algorithms (forloop implementation)
        * Every optimizer's algorithm is most readably implemented through a big for-loop
          over all the parameters, which is what we refer to as the forloop or single tensor
          implementation. These algorithms are manually validated by comparing to the paper
          and systematically validated by ensuring that the loss goes in the right direction
          when the optimizer has been applied.
        * This implementation should compose well with optimizer hyperparameters, such as
          supporting Tensor LRs, the capturable API, and sparse and complex parameters.
    - Each varying implementation
        * We then have implementations that improve upon the performance of the forloop
          implementation by leveraging fusion, namely our foreach (multi_tensor) and fused
          implementations.
        * These variations are validated numerically by comparing with the forloop version
          of the optimizer. In fact, we test most variations this way--we see the forloop
          implementation as the ground truth and expect that improvements to it in any way
          should be just as correct.
        * Both params and optimizer states should be validated numerically.
    - state_dict APIs
        * The optimizer instance should be serializable
        * Calling save and load should be deterministic
        * Moving between devices should be seamless
        * BC - load_state_dict should be able to handle older optimizer states
    - Hook APIs (everything should fire in the right order)
    - LR Scheduler integration (composing should not error + should go the right direction)
    - Parameter groups (should be equivalent to having multiple optimizers)
    - Erroring (what should error should error)

    We also cover different ways of generating parameters and grads:
    - With parameters, we either generate them randomly given specific shapes or we take
      them from a sample NN module.
        * Variety is important here because NN modules have type Parameter and randomly
          generated tensors have type Tensor.
        * Parameters can be sparse for a subset of the optimizers (check out OptimizerInfo)
        * Complex parameters should be handled using view_as_real
        * Parameters can be spread across different devices and different dtypes for any
          given optimizer
        * Parameters can be contiguous and noncontiguous
    - With grads, we follow suit from the parameters.
        * Grads can also be None, empty, or zero-valued, and this should not disrupt training.
    """

    @onlyCPU
    @optims(optim_db)
    def test_optim_infos_do_not_specify_global_cliquey_kwargs(
        self, device, dtype, optim_info
    ):
        global_cliquey_flags = ["foreach", "fused", "differentiable"]
        for optim_input in optim_info.optim_inputs_func(device=device):
            self.assertFalse(
                any(f for f in global_cliquey_flags if f in optim_input.kwargs)
            )

    @optims([optim for optim in optim_db if optim.optim_error_inputs_func is not None])
    def test_errors(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        error_inputs = optim_info.optim_error_inputs_func(device=device, dtype=dtype)

        for error_input in error_inputs:
            optim_input = error_input.optimizer_error_input
            params, kwargs = optim_input.params, optim_input.kwargs
            if error_input.error_on == OptimizerErrorEnum.CONSTRUCTION_ERROR:
                if issubclass(error_input.error_type, Warning):
                    with self.assertWarnsRegex(
                        error_input.error_type, error_input.error_regex
                    ):
                        optim_cls(params, **kwargs)
                else:
                    with self.assertRaisesRegex(
                        error_input.error_type, error_input.error_regex
                    ):
                        optim_cls(params, **kwargs)
            elif error_input.error_on == OptimizerErrorEnum.STEP_ERROR:
                optim = optim_cls(params, **kwargs)
                if issubclass(error_input.error_type, Warning):
                    with self.assertWarnsRegex(
                        error_input.error_type, error_input.error_regex
                    ):
                        optim.step()
                else:
                    with self.assertRaisesRegex(
                        error_input.error_type, error_input.error_regex
                    ):
                        optim.step()
            else:
                raise NotImplementedError(f"Unknown error type {error_input.error_on}")

    @parametrize("contiguous", [True, False])
    @parametrize("with_lrsched", [True, False])
    @optims(optim_db, dtypes=[torch.float32])
    def test_forloop_goes_right_direction(
        self, device, dtype, optim_info, contiguous, with_lrsched
    ):
        optim_cls = optim_info.optim_cls
        schedulers_constructors = (
            optim_info.scheduler_inputs if with_lrsched else [None]
        )

        for schedulers_constructor in schedulers_constructors:
            # With a tensor LR we need fresh inputs for each scheduler,
            # otherwise mutating the LR carries across iterations.
            optim_inputs = optim_info.optim_inputs_func(device=device)
            for optim_input in optim_inputs:
                if "foreach" in optim_info.supported_impls:
                    optim_input.kwargs["foreach"] = False  # force forloop
                if contiguous:
                    weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
                    bias = Parameter(torch.randn((10), device=device, dtype=dtype))
                else:
                    weight = Parameter(
                        torch.randn((10, 5, 2), device=device, dtype=dtype)[..., 0]
                    )
                    bias = Parameter(
                        torch.randn((10, 2), device=device, dtype=dtype)[..., 0]
                    )
                input = torch.randn(5, device=device, dtype=dtype)

                optimizer = optim_cls([weight, bias], **optim_input.kwargs)
                schedulers = [
                    s(optimizer)
                    for s in (schedulers_constructor if schedulers_constructor else [])
                ]

                def closure():
                    optimizer.zero_grad()
                    loss = (weight.mv(input) + bias).pow(2).sum()
                    loss.backward()
                    if optim_info.only_supports_sparse_grads:
                        # For this test, we naively convert the Tensor layout, which we know does
                        # NOT represent the expected use case for optims like SparseAdam!
                        weight.grad = weight.grad.to_sparse()
                        bias.grad = bias.grad.to_sparse()
                    return loss

                initial_value = closure().item()
                for _ in range(20):
                    if optim_info.step_requires_closure:
                        loss = optimizer.step(closure)
                    else:
                        loss = closure()
                        optimizer.step()

                    for scheduler in schedulers:
                        if isinstance(scheduler, ReduceLROnPlateau):
                            scheduler.step(loss)
                        else:
                            scheduler.step()

                if optim_input.kwargs.get("maximize", False):
                    self.assertGreater(closure().item(), initial_value)
                else:
                    self.assertLess(closure().item(), initial_value)

    @onlyCUDA
    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    @parametrize("with_lrsched", [True, False])
    @optims(optim_db, dtypes=[torch.float32])
    def test_forloop_goes_right_direction_multigpu(
        self, device, dtype, optim_info, with_lrsched
    ):
        optim_cls = optim_info.optim_cls
        schedulers_constructors = (
            optim_info.scheduler_inputs if with_lrsched else [None]
        )
        for schedulers_constructor in schedulers_constructors:
            # We need a fresh set of inputs if we have a tensor LR
            # to not carry mutations across iterations.
            optim_inputs = optim_info.optim_inputs_func(device=device)
            for optim_input in optim_inputs:
                if "foreach" in optim_info.supported_impls:
                    optim_input.kwargs["foreach"] = False  # force forloop

                weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype))
                bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype))
                inpt = torch.randn(5, device="cuda:0", dtype=dtype)

                optimizer = optim_cls([weight, bias], **optim_input.kwargs)
                schedulers = [
                    s(optimizer)
                    for s in (schedulers_constructor if schedulers_constructor else [])
                ]

                def closure():
                    optimizer.zero_grad()
                    loss = (weight.mv(inpt).cuda(1) + bias).pow(2).sum()
                    loss.backward()
                    if optim_info.only_supports_sparse_grads:
                        # For this test, we naively convert the Tensor layout, which we know does
                        # NOT represent the expected use case for optims like SparseAdam!
                        weight.grad = weight.grad.to_sparse()
                        bias.grad = bias.grad.to_sparse()
                    return loss

                initial_value = closure().item()
                for _ in range(20):
                    loss = optimizer.step(closure)
                    for scheduler in schedulers:
                        if isinstance(scheduler, ReduceLROnPlateau):
                            scheduler.step(loss)
                        else:
                            scheduler.step()

                if optim_input.kwargs.get("maximize", False):
                    self.assertGreater(closure().item(), initial_value)
                else:
                    self.assertLess(closure().item(), initial_value)

    @optims(optim_db, dtypes=[torch.float32])
    def test_param_group_with_lrscheduler_goes_right_direction(
        self, device, dtype, optim_info
    ):
        optim_cls = optim_info.optim_cls

        for schedulers_c in optim_info.scheduler_inputs:
            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
            inpt = torch.randn(5, device=device, dtype=dtype)

            # avoid endless recompiles by wrapping LR in a tensor if we're compiling
            lr = torch.tensor(0.01) if torch._utils.is_compiling() else 0.01
            optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": lr}])
            schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c]

            def closure():
                optimizer.zero_grad()
                loss = (weight.mv(inpt) + bias).pow(2).sum()
                loss.backward()
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
                    # NOT represent the expected use case for optims like SparseAdam!
                    weight.grad = weight.grad.to_sparse()
                    bias.grad = bias.grad.to_sparse()
                return loss

            initial_value = closure().item()
            for _ in range(20):
                loss = optimizer.step(closure)
                for scheduler in schedulers:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(loss)
                    else:
                        scheduler.step()

            self.assertLess(closure().item(), initial_value)

    @optims(optim_db, dtypes=[torch.float32])
    def test_tensor_lr(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        for optim_input in all_optim_inputs:
            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
            weight_c = weight.clone().detach().requires_grad_(True)
            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
            bias_c = bias.clone().detach().requires_grad_(True)
            inpt = torch.randn(5, device=device, dtype=dtype)

            kwargs = optim_input.kwargs
            if "lr" in kwargs:
                del kwargs["lr"]

            kwargs["lr"] = 1.0 if optim_info.step_requires_closure else 1e-3
            optimizer_r = optim_cls([weight, bias], **kwargs)

            try:
                kwargs["lr"] = torch.tensor(kwargs["lr"])
                optimizer = optim_cls([weight_c, bias_c], **kwargs)
            except ValueError as e:
                self.assertRegex(str(e), ".*lr as a Tensor is not supported.*")
                continue

            def closure(optim, w, b, i):
                optim.zero_grad()
                loss = (w.mv(i) + b).pow(2).sum()
                loss.backward()
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
                    # NOT represent the expected use case for optims like SparseAdam!
                    w.grad = w.grad.to_sparse()
                    b.grad = b.grad.to_sparse()
                return loss

            for _ in range(5):
                if optim_info.step_requires_closure:
                    optimizer_r.step(
                        functools.partial(closure, optimizer_r, weight, bias, inpt)
                    )
                    optimizer.step(
                        functools.partial(closure, optimizer, weight_c, bias_c, inpt)
                    )
                else:
                    closure(optimizer_r, weight, bias, inpt)
                    closure(optimizer, weight_c, bias_c, inpt)

                self.assertEqual(weight, weight_c)
                self.assertEqual(bias, bias_c)

    @parametrize("with_lrsched", [True, False])
    @optims(
        [o for o in optim_db if o.supports_sparse or o.only_supports_sparse_grads],
        dtypes=[torch.float64],
    )
    def test_rosenbrock_sparse(self, device, dtype, optim_info, with_lrsched):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        # Fused impls do not support sparse gradients
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable", "fused")
        )
        kwarg_updates, schedulers_constructors = optim_info.metadata_for_sparse

        if with_lrsched and len(schedulers_constructors) == 0:
            return

        supported_inputs = []
        if len(kwarg_updates) != 0:
            seen = set()
            for i in all_optim_inputs:
                for k in kwarg_updates:
                    if k in i.kwargs:
                        del i.kwargs[k]
                hashable_kwargs = tuple(sorted(i.kwargs.items()))
                if len(i.kwargs) > 0 and hashable_kwargs not in seen:
                    supported_inputs.append(i)
                    seen.add(hashable_kwargs)
                    if "lr" in kwarg_updates:
                        i.kwargs["lr"] = kwarg_updates["lr"]
        else:
            supported_inputs = all_optim_inputs

        for optim_input in supported_inputs:
            kwargs = optim_input.kwargs
            multi_tensor = kwargs.get("foreach", False)

            # For rosenbrock tests, it is mandated that the param is a tensor with 2 numbers
            if multi_tensor:
                params_t = [
                    torch.tensor([1.5, 1.5]),
                    torch.tensor([1.5, 1.5], dtype=dtype),
                ]
            else:
                params_t = [torch.tensor([1.5, 1.5])]

            params = [Parameter(param_t) for param_t in params_t]
            optimizer = optim_cls(params, **kwargs)
            schedulers = [
                s(optimizer) for s in (schedulers_constructors if with_lrsched else [])
            ]

            if not optim_info.only_supports_sparse_grads:
                params_c = [Parameter(param_t.clone()) for param_t in params_t]
                optimizer_c = optim_cls(params_c, **kwargs)
                schedulers_c = [
                    s(optimizer_c)
                    for s in (schedulers_constructors if with_lrsched else [])
                ]

            solution = torch.tensor([1, 1])
            with torch.no_grad():
                initial_dist = sum(param.dist(solution) for param in params)

            def get_grad(param, sparse_grad, w):
                grad = drosenbrock(param)
                # NB: We torture test the optimizer by returning an
                # uncoalesced sparse tensor

                # Depending on w, provide only the x or y gradient
                if sparse_grad:
                    if w:
                        i = torch.tensor([[0, 0]], dtype=torch.int64)
                        x = grad[0]
                        v = torch.tensor([x / 4.0, x - x / 4.0])
                    else:
                        i = torch.tensor([[1, 1]], dtype=torch.int64)
                        y = grad[1]
                        v = torch.tensor([y - y / 4.0, y / 4.0])
                    grad_out = torch.sparse_coo_tensor(i, v, (2,), dtype=v.dtype)
                else:
                    if w:
                        grad_out = torch.tensor([grad[0], 0], dtype=param.dtype)
                    else:
                        grad_out = torch.tensor([0, grad[1]], dtype=param.dtype)
                return grad_out

            def eval(params, sparse_grad, w):
                optimizer.zero_grad()
                if multi_tensor:
                    loss = sum(rosenbrock(param) for param in params)
                else:
                    loss = rosenbrock(params[0])
                loss.backward()

                grads_out = [get_grad(param, sparse_grad, w) for param in params]
                with torch.no_grad():
                    params[0].grad = grads_out[0]
                    if multi_tensor:
                        params[1].grad = grads_out[1].to(dtype=dtype)
                return loss

            for i in range(1800):
                # Do cyclic coordinate descent
                w = i % 2
                optimizer.step(functools.partial(eval, params, True, w))
                for scheduler in schedulers:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(rosenbrock(params[0]))
                    else:
                        scheduler.step()
                if not optim_info.only_supports_sparse_grads:
                    optimizer_c.step(functools.partial(eval, params_c, False, w))
                    for scheduler in schedulers_c:
                        if isinstance(scheduler, ReduceLROnPlateau):
                            scheduler.step(rosenbrock(params_c[0]))
                        else:
                            scheduler.step()
                    # Tolerance is increased due to floating point error from the different
                    # code path for the dense case: x vs. x - x / 4.0 + x / 4.0
                    self.assertEqual(params, params_c, atol=5e-6, rtol=5e-6)

            if not kwargs.get("maximize", False):
                self.assertLessEqual(
                    sum(param.dist(solution) for param in params), initial_dist
                )
            else:
                self.assertGreaterEqual(
                    sum(rosenbrock(param) for param in params),
                    sum(rosenbrock(param_t) for param_t in params_t),
                )

    @skipMPS
    @optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
    def test_complex(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        # Also skip fused, since our fused kernels do not support complex
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable", "fused")
        )
        for optim_input in all_optim_inputs:
            # Last param is intentionally real to test that we can mix real and complex
            complex_params = [
                torch.randn(10, 5, device=device, dtype=dtype, requires_grad=True),
                torch.randn(10, device=device, dtype=dtype, requires_grad=True),
                torch.randn(
                    10, 5, device=device, dtype=torch.float32, requires_grad=True
                ),
            ]
            real_params = [
                (
                    torch.view_as_real(param).detach().clone().requires_grad_()
                    if param.is_complex()
                    else param.detach().clone().requires_grad_()
                )
                for param in complex_params
            ]

            complex_optimizer = optim_cls(complex_params, **optim_input.kwargs)
            real_optimizer = optim_cls(real_params, **optim_input.kwargs)
            real_steps = []
            complex_steps = []
            grads_losses = []

            def real_closure():
                for param in real_params:
                    grad = torch.randn_like(param)
                    param.grad = grad
                    real_steps.append(param.detach().clone())
                    grads_losses.append(grad.clone())
                loss = torch.randn(1)
                grads_losses.append(loss.clone())
                return loss

            def complex_closure():
                for param in complex_params:
                    if torch.is_complex(param):
                        grad = torch.view_as_complex(grads_losses.pop(0))
                        complex_steps.append(torch.view_as_real_copy(param.detach()))
                    else:
                        grad = grads_losses.pop(0)
                        complex_steps.append(param.detach().clone())
                    param.grad = grad
                return grads_losses.pop(0)

            for _ in range(3):
                if optim_info.step_requires_closure:
                    # LBFGS, for example, requires closure and calls it internally
                    real_optimizer.step(real_closure)
                    complex_optimizer.step(complex_closure)
                else:
                    # For other optimizers, we call the closure explicitly to set the gradients
                    real_closure()
                    complex_closure()
                    real_optimizer.step()
                    complex_optimizer.step()

            # Final Parameters should be the same
            complex_params_asreal = [
                torch.view_as_real(param) if param.is_complex() else param
                for param in complex_params
            ]
            self.assertEqual(real_params, complex_params_asreal)

            # All intermediate steps should also be the same;
            # this also checks steps taken within, for example, a line search
            self.assertEqual(complex_steps, real_steps)

    @skipMPS
    @optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
    def test_complex_2d(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        # Also skip fused, since our fused kernels do not support complex
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable", "fused")
        )
        for optim_input in all_optim_inputs:
            if optim_info.step_requires_closure:
                # Why? The way we implement complex is by turning complex params into view_as_real
                # alternatives. For example, a size (M, N) tensor will become (M, N, 2). In this test,
                # we break apart a tensor into its real and imaginary parts, which would be 2x(M, N).
                # For pointwise optimizers, this distinction is trivial, but for LBFGS, where
                # there are reductions across all parameters (and all the grads get flattened into
                # one long Tensor), this ordering matters. Why? Reductions are not deterministic
                # because addition between floating point numbers is not associative, i.e.,
                # (a + b) + c may not equal a + (b + c). Thus, we add a seed here to control the
                # discrepancy that will happen with LBFGS. Note that in test_complex above, there
                # is no need for a seed nor for increased tolerance, because results should be
                # bitwise equivalent.
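                # (Hedged illustration, not from the original comment: for a complex tensor
                # c of shape (3, 4), torch.view_as_real(c) has shape (3, 4, 2), whereas
                # splitting into c.real and c.imag gives two separate (3, 4) tensors, so
                # LBFGS flattens and reduces the same values in a different order.)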
                torch.manual_seed(2024)

            a1 = torch.randn(2, device=device, dtype=dtype, requires_grad=True)
            a1_real = a1.real.clone().detach()
            a1_imag = a1.imag.clone().detach()
            a1_real.requires_grad_()
            a1_imag.requires_grad_()
            optim1 = optim_cls([a1], **optim_input.kwargs)
            optim2 = optim_cls([a1_real, a1_imag], **optim_input.kwargs)

            a1_reals = TensorTracker()
            a1_imags = TensorTracker()
            a1_grad_reals = TensorTracker()
            a1_grad_imags = TensorTracker()
            losses = TensorTracker()

            def closure1():
                optim1.zero_grad()
                loss = rosenbrock(a1).abs()
                loss.backward()

                # Track clones to best test accuracy
                a1_reals.add(a1.real)
                a1_imags.add(a1.imag)
                a1_grad_reals.add(a1.grad.real)
                a1_grad_imags.add(a1.grad.imag)

                losses.add(loss)

                return loss

            def closure2():
                optim2.zero_grad()
                a1_reals.pop_check_set(a1_real, self)
                a1_imags.pop_check_set(a1_imag, self)
                a2 = torch.complex(a1_real, a1_imag)
                loss = rosenbrock(a2).abs()
                losses.pop_check_set(loss, self)
                loss.backward()
                a1_grad_reals.pop_check_set(a1_real.grad, self)
                a1_grad_imags.pop_check_set(a1_imag.grad, self)
                return loss

            for _ in range(3):
                if optim_info.step_requires_closure:
                    # LBFGS, for example, requires closure and calls it internally
                    optim1.step(closure1)
                    optim2.step(closure2)
                else:
                    closure1()
                    closure2()
                    optim1.step()
                    optim2.step()

                self.assertEqual(a1.real, a1_real)
                self.assertEqual(a1.imag, a1_imag)

            self.assertTrue(a1_reals.all_popped())
            self.assertTrue(a1_imags.all_popped())
            self.assertTrue(a1_grad_reals.all_popped())
            self.assertTrue(a1_grad_imags.all_popped())
            self.assertTrue(losses.all_popped())

    def _compare_between(
        self, inputs, models, optimizers, assert_eq_kwargs=None, assert_step_dtype=None
    ):
        # why 7? iteration 7 is where we start to see differences for RAdam
        # params interacting with the small eps value, because that's right
        # after rho_t becomes greater than 5 in step 6.
        if assert_eq_kwargs is None:
            assert_eq_kwargs = {}
        kIterations = 7
        tracker = TensorTracker(assert_eq_kwargs)
        for i in range(kIterations):
            state, updated_params = [], []
            if not isinstance(inputs, list):
                inputs = [inputs, inputs]
            for input, model, optimizer in zip(inputs, models, optimizers):
                optimizer.zero_grad()

                if i == 3:
                    # Freeze a layer to test that the step for this layer in 'fused' or 'foreach'
                    # is the same as the step in 'forloop'.
                    model[2].requires_grad_(False)
                if i == 5:
                    # Unfreeze the layer after 2 iterations (it was frozen at i == 3).
                    model[2].requires_grad_(True)

                # Test that step behaves as expected (a no-op) when grads are set to None
                if i != 2:
                    output = model(input)
                    loss = output.sum()
                    loss.backward()

                optimizer.step()
                state.append(optimizer.state)
                updated_params.append(model.parameters())

            og_state, new_state = state
            for og_p, new_p in zip(updated_params[0], updated_params[1]):
                tracker.add(og_p)
                tracker.pop_check_set(new_p, self)

                # check that optimizer states are the same
                og_p_state = og_state[og_p]
                new_p_state = new_state[new_p]
                if assert_step_dtype is not None:
                    if torch.is_tensor(og_p_state.get("step", None)):
                        self.assertEqual(og_p_state["step"].dtype, assert_step_dtype)
                    if torch.is_tensor(new_p_state.get("step", None)):
                        self.assertEqual(new_p_state["step"].dtype, assert_step_dtype)
                for k in og_p_state:
                    tracker.add(og_p_state[k])
                    tracker.pop_check_set(new_p_state[k], self)

            self.assertTrue(tracker.all_popped())

    def _test_derived_optimizers(
        self,
        device,
        dtype,
        optim_info,
        flag,
        reduced_precision=False,
        assert_step_dtype=None,
    ):
        """
        Given a flag 'fused' or 'foreach', test for parity of optimizer state
        and updated parameters between when the flag is set to True and False
        for provided optimizer configurations.
        """
        assert flag in ("foreach", "fused")
        assert_eq_kwargs = {} if not reduced_precision else FP16_REDUCED_PRECISION

        optim_inputs = optim_info.optim_inputs_func(device=device, dtype=dtype)
        optim_cls = optim_info.optim_cls
        for optim_input in optim_inputs:
            models, optimizers = [], []
            kwargs = deepcopy(optim_input.kwargs)
            if kwargs.get("capturable", False) and _get_device_type(device) == "cpu":
                # capturable is not supported on CPU
                continue
            for flag_value in (False, True):
                kwargs[flag] = flag_value
                input = torch.tensor(
                    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=device
                ).reshape(3, 2)

                torch.manual_seed(1)
                model = torch.nn.Sequential(
                    torch.nn.Linear(2, 3),
                    torch.nn.Sigmoid(),
                    torch.nn.Linear(3, 1),
                    torch.nn.Sigmoid(),
                )
                model.to(dtype=dtype, device=device)

                # foreach/fused optimizers should be tested with a
                # zero_size tensor as its last param.
                # ref: https://github.com/pytorch/pytorch/issues/100701
                empty_param = torch.empty(
                    (), device=device, dtype=dtype, requires_grad=True
                )
                empty_param.grad = torch.rand_like(empty_param)
                params = list(model.parameters()) + [empty_param]

                optimizer = optim_cls(params, **kwargs)
                models.append(model)
                optimizers.append(optimizer)

            self._compare_between(
                input, models, optimizers, assert_eq_kwargs, assert_step_dtype
            )

    @skipMPS  # MPS doesn't support torch.float64, see https://github.com/pytorch/pytorch/issues/115350
    @optims(
        [optim for optim in optim_db if "foreach" in optim.supported_impls],
        dtypes=[torch.float64],
    )
    def test_foreach_matches_forloop(self, device, dtype, optim_info):
        self._test_derived_optimizers(device, dtype, optim_info, "foreach")

    @onlyCUDA
    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    @parametrize("impl", ["foreach", "fused"])
    @optims(
        [
            optim
            for optim in optim_db
            if "foreach" in optim.supported_impls or "fused" in optim.supported_impls
        ]
    )
    def test_mixed_device_dtype(self, device, dtype, optim_info, impl):
        """
        Similar in essence to _test_derived_optimizers above. The main difference is that
        _test_derived_optimizers uses model parameters whereas we randomly pass in
        parameters of different dtypes and devices here. We need multiple GPUs (vs just a
        CPU and GPU) because fused adam only works on GPUs. (Thus we only run the tests
        that call into this helper when TEST_MULTIGPU.)
        """
        assert impl in ("foreach", "fused")
        if impl == "foreach" and "foreach" not in optim_info.supported_impls:
            return unittest.skip(
                f"foreach not supported for {optim_info.optim_cls.__name__}"
            )
        elif impl == "fused" and "cuda" not in optim_info.supports_fused_on:
            return unittest.skip(
                f"fused not supported for {optim_info.optim_cls.__name__} on cuda"
            )

        params = [
            torch.rand(2, 3, dtype=torch.float64, device="cuda:0", requires_grad=True),
            torch.rand(2, 3, dtype=torch.float32, device="cuda:0", requires_grad=True),
            torch.rand(2, 3, dtype=torch.float16, device="cuda:0", requires_grad=True),
            torch.rand(2, 3, dtype=torch.bfloat16, device="cuda:0", requires_grad=True),
            torch.rand(2, 3, dtype=torch.float64, device="cuda:1", requires_grad=True),
            torch.rand(2, 3, dtype=torch.float32, device="cuda:1", requires_grad=True),
            torch.rand(2, 3, dtype=torch.float16, device="cuda:1", requires_grad=True),
            torch.rand(2, 3, dtype=torch.bfloat16, device="cuda:1", requires_grad=True),
            torch.randint(
                1024, (2, 3), dtype=torch.int64, device="cuda:1", requires_grad=False
            ),
        ]

        for p in params:
            if p.requires_grad:
                p.grad = torch.rand_like(p, device=p.device, dtype=p.dtype)

        kIterations = 7 if impl == "foreach" else 1
        optim_inputs = optim_info.optim_inputs_func(device=device)
        optim_cls = optim_info.optim_cls
        for optim_input in optim_inputs:
            updated_params, state = [], []
            kwargs = deepcopy(optim_input.kwargs)
            if kwargs.get("capturable", False) and _get_device_type(device) == "cpu":
                # capturable is not supported on CPU
                continue
            for use_impl in (False, True):
                kwargs[impl] = use_impl
                params_clone = []
                for p in params:
                    p_clone = p.clone().detach()
                    if p.requires_grad:
                        p_clone.requires_grad = True
                        p_clone.grad = p.grad.clone().detach()
                    params_clone.append(p_clone)

                optimizer = optim_cls(params_clone, **kwargs)
                for _ in range(kIterations):
                    optimizer.step()

                state.append(optimizer.state)
                updated_params.append(params_clone)

            og_state, new_state = state
            for og_p, new_p in zip(updated_params[0], updated_params[1]):
                # Increase the tolerance since we are collating lots of ops together for
                # optimizers and the designated tolerances are for a single op only.
                single_rtol, single_atol = torch.testing._comparison.get_tolerances(
                    new_p.dtype, rtol=None, atol=None
                )
                rtol = 5 * single_rtol
                atol = 5 * single_atol

                self.assertEqual(og_p, new_p, rtol=rtol, atol=atol)

                # check that optimizer states are the same
                og_p_state = og_state[og_p]
                new_p_state = new_state[new_p]

                for k in og_p_state:
                    actual = new_p_state[k]
                    self.assertEqual(og_p_state[k], actual, rtol=rtol, atol=atol)

    @onlyCUDA
    @optims(
        [optim for optim in optim_db if "foreach" in optim.supported_impls],
        dtypes=[torch.float64],
    )
    def test_set_default_dtype_works_with_foreach(self, device, dtype, optim_info):
        # https://github.com/pytorch/pytorch/issues/110940
        # We coerce step to always be float32 unless the
        # default dtype is the higher precision float64
        old_default_dtype = torch.get_default_dtype()
        for default_dtype in [torch.float64, torch.float16]:
            try:
                torch.set_default_dtype(default_dtype)
                self._test_derived_optimizers(
                    device,
                    dtype,
                    optim_info,
                    "foreach",
                    reduced_precision=default_dtype == torch.float16,
                    assert_step_dtype=(
                        torch.float64
                        if default_dtype == torch.float64
                        else torch.float32
                    ),
                )
            finally:
                torch.set_default_dtype(old_default_dtype)

    @onlyCUDA
    @largeTensorTest("72GB", "cuda")
    @optims(
        [optim for optim in optim_db if "foreach" in optim.supported_impls],
        dtypes=[torch.float16],
    )
    def test_foreach_large_tensor(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        optim_inputs = optim_info.optim_inputs_func(device=device)
        for optim_input in optim_inputs:
            params = [torch.ones(2**32, device=device, dtype=dtype)]
            params[0].grad = torch.zeros_like(params[0])
            optimizer = optim_cls(params, foreach=True, **optim_input.kwargs)
            optimizer.step()

    @onlyCUDA
    @optims(
        [optim for optim in optim_db if "foreach" in optim.supported_impls],
        dtypes=[torch.float32],
    )
    def test_peak_memory_foreach(self, device, dtype, optim_info):
        nparams = 10
        optim_inputs = optim_info.optim_inputs_func(device=device)
        optim_cls = optim_info.optim_cls
        for optim_input in optim_inputs:
            kwargs = deepcopy(optim_input.kwargs)
            max_mems = []
            for flag_value in (False, True):
                kwargs["foreach"] = flag_value
                # The 16 * 8 = 128 is critical here! Our CUDACachingAllocator allocates in blocks
                # of 512 bytes, meaning any tensor that occupies <512 bytes of memory will allocate
                # a whole 512 bytes anyway. We use 128 (since each float32 element is 4 bytes) so
                # that the param is exactly 512 bytes, making our later calculations for
                # intermediate_size easy.
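                # (Hedged arithmetic note, added for clarity: 16 * 8 = 128 float32
                # elements at 4 bytes each is exactly 512 bytes, i.e., one allocator
                # block, so nparams * param.nelement() * param.element_size() below
                # counts whole blocks with no rounding slack.)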
                param = torch.rand(16, 8, device=device, dtype=dtype)
                params = [torch.rand_like(param) for _ in range(nparams)]

                optimizer = optim_cls(params, **kwargs)

                for p in params:
                    p.grad = torch.rand_like(p)

                optimizer.step()
                import gc

                gc.collect()
                torch.cuda.reset_peak_memory_stats()
                optimizer.step()
                gc.collect()
                max_mems.append(torch.cuda.max_memory_allocated())

            st_max_mem, mt_max_mem = max_mems
            intermediate_size = nparams * param.nelement() * param.element_size()
            nintermediates = 1  # we expect a budget of 1 intermediate most of the time

            # Check the param group directly to handle if the compiler set capturable
            if optimizer.param_groups[0].get(
                "capturable", False
            ) or optim_cls.__name__ in ["Adadelta", "ASGD", "RAdam"]:
                # with capturable in Adam(W), we have 2 extra intermediates for the bias_corrections
                # with Adadelta, we have 2 extra for (acc_delta + eps) and (square_avg + eps)
                # ASGD allocates axs, 2x mus, 2x etas, and grads at the same time
                nintermediates = 3
                if optim_cls.__name__ == "NAdam":
                    # with capturable in NAdam, we have 3 extra intermediates for the
                    # bias_correction, mus, and mu_nexts
                    if TEST_WITH_TORCHDYNAMO:
                        # With dynamo, the eager/FX backend appears to hold memory longer than
                        # vanilla eager: https://github.com/pytorch/pytorch/issues/125511
                        nintermediates = 8
                    else:
                        nintermediates = 5

                if optim_cls.__name__ == "RAdam":
                    # RAdam has four intermediates with capturable:
                    # num, unrect_step_size, buffer, grouped_grads
                    if TEST_WITH_TORCHDYNAMO:
                        # With dynamo, the eager/FX backend appears to hold memory longer than
                        # vanilla eager: https://github.com/pytorch/pytorch/issues/125511
                        nintermediates = 6
                    else:
                        nintermediates = 4

            elif optim_cls.__name__ in ["NAdam", "Adagrad", "RMSprop", "Adafactor"]:
                # NAdam uses two intermediates at the same time (grads & exp_avg_sq_sqrt)
                # Adagrad uses std and grads at the same time
                # RMSprop uses avg and grads
                # Adafactor uses row/col var and its mean
                nintermediates = 2

                if optim_cls.__name__ == "Adafactor" and kwargs.get("maximize", False):
                    # When maximize is True, Adafactor also tracks device_grad
                    nintermediates = 3

            # Dynamo ST uses less memory than eager in the case of Adam/Adagrad/Nadam/RAdam,
            # which makes the foreach memory check fail
            if TEST_WITH_TORCHDYNAMO:
                st_max_mem += 6000

            expected_max_mem = st_max_mem + intermediate_size * nintermediates
            # hipcc currently can't generate efficient code for the small buffer optimization
            # code path (see Note [small buffer optimization] for details), thus we always
            # dynamically allocate the tensor metadata for ROCm. Adjusting the expected max
            # memory usage to account for this.
            if TEST_WITH_ROCM:
                expected_max_mem *= 1.02

            self.assertLessEqual(mt_max_mem, expected_max_mem)

    @optims(
        [optim for optim in optim_db if "fused" in optim.supported_impls],
        dtypes=floating_types_and(
            torch.bfloat16,
            torch.float16,
        ),
    )
    def test_fused_matches_forloop(self, device, dtype, optim_info):
        if _get_device_type(device) not in optim_info.supports_fused_on:
            self.skipTest(
                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
            )
        if _get_device_type(device) == "mps" and dtype not in (
            torch.float16,
            torch.float32,
        ):
            self.skipTest("MPS supports only torch.float16 and torch.float32")
        self._test_derived_optimizers(device, dtype, optim_info, "fused")

    @optims(
        [optim for optim in optim_db if "fused" in optim.supported_impls],
        dtypes=(torch.float32,),
    )
    def test_fused_error_on_params_on_meta(self, device, dtype, optim_info):
        if _get_device_type(device) not in optim_info.supports_fused_on:
            self.skipTest(
                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
            )

        with torch.device("meta"):
            model = torch.nn.Sequential(
                torch.nn.Linear(2, 3),
                torch.nn.Sigmoid(),
                torch.nn.Linear(3, 1),
                torch.nn.Sigmoid(),
            ).to(dtype)

        optimizer = optim_info.optim_cls(model.parameters(), fused=True)
        with torch.device("meta"):
            for p in model.parameters():
                p.grad = torch.rand_like(p)

        with self.assertRaisesRegex(
            RuntimeError,
            "`fused=True` requires all the params to be floating point Tensors",
        ):
            optimizer.step()

        optimizer.zero_grad(set_to_none=True)
        model.to_empty(device=device)
        for p in model.parameters():
            p.grad = torch.rand_like(p)
        optimizer.step()

    @onlyNativeDeviceTypes
    @largeTensorTest("64GB")
    @optims(
        [optim for optim in optim_db if "fused" in optim.supported_impls],
        dtypes=[torch.float16],
    )
    def test_fused_large_tensor(self, device, dtype, optim_info):
        if device not in optim_info.supports_fused_on:
            self.skipTest(
                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
            )
        optim_cls = optim_info.optim_cls
        optim_inputs = optim_info.optim_inputs_func(device=device)
        for optim_input in optim_inputs:
            params = [torch.ones(2**32, device=device, dtype=dtype)]
            params[0].grad = torch.zeros_like(params[0])
            optimizer = optim_cls(params, fused=True, **optim_input.kwargs)
            optimizer.step()

    @onlyCUDA
    @optims(
        [optim for optim in optim_db if "fused" in optim.supported_impls],
        dtypes=[torch.float32],
    )
    def test_fused_does_not_step_if_foundinf(self, device, dtype, optim_info):
        if device not in optim_info.supports_fused_on:
            self.skipTest(
                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
            )
        optim_cls = optim_info.optim_cls
        optim_inputs = optim_info.optim_inputs_func(device=device)
        num_params = 5
        for optim_input in optim_inputs:
            for no_grad_scale in (False, True):
                params = [
                    torch.ones((1,), device=device, dtype=dtype)
                    for _ in range(num_params)
                ]
                params_c = [param.clone().detach() for param in params]
                for p in params:
                    p.grad = torch.ones_like(p)
                optimizer = optim_cls(params, fused=True, **optim_input.kwargs)
                optimizer.grad_scale = (
                    None
                    if no_grad_scale
                    else torch.ones((1,), dtype=dtype, device=device)
                )
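                # (Hedged note, not in the original: grad_scale and found_inf mirror the
                # tensors a GradScaler would hand to a fused optimizer; a non-zero
                # found_inf should make the fused kernel skip the update entirely,
                # which is what the assertions below check.)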
                optimizer.found_inf = torch.ones((), dtype=dtype, device=device)
                optimizer.step()
                for p in params:
                    if "step" in optimizer.state[p]:
                        self.assertEqual(
                            torch.zeros((), dtype=dtype, device=device),
                            optimizer.state[p]["step"],
                        )
                self.assertEqual(params, params_c)

    @parametrize("impl", ["fused", "capturable"])
    @optims(
        [optim for optim in optim_db if "fused" in optim.supported_impls],
        dtypes=[torch.float32],
    )
    def test_cpu_load_state_dict(self, device, dtype, impl, optim_info):
        # NOTE: This SIMULATES a fused/capturable optimizer with state moved to CPU, issue 103256
        # How do we get there? Users typically create CUDA models with fused optimizers, then
        # store checkpoints on CPU (since CUDA memory is limited) and later load them with
        # torch.load(..., map_location="cpu").
        # Since this is a unit test, it is more expedient to simulate what the state_dict
        # would look like, which is basically CPU tensors with the fused/capturable flag = True.
        optim_cls = optim_info.optim_cls
        opt_name = optim_cls.__name__
        if opt_name in ("SGD", "Adagrad") and impl == "capturable":
            # Capturable SGD/Adagrad does not exist
            self.skipTest("SGD and Adagrad do not currently support capturable")
        if _get_device_type(device) == "cpu":
            self.skipTest("Test is only for non-cpu devices")
        elif (
            impl == "fused"
            and _get_device_type(device) not in optim_info.supports_fused_on
        ):
            self.skipTest(f"{device} is not supported for fused on {opt_name}")
        elif impl == "capturable" and _get_device_type(device) == "mps":
            self.skipTest("MPS does not support capturable")

        cpu_optim_inputs = optim_info.optim_inputs_func(device="cpu")
        for optim_input in cpu_optim_inputs:
            param = torch.tensor([0.1, 0.2], dtype=dtype, device="cpu")
            optimizer = optim_cls([param], **optim_input.kwargs)
            param.grad = torch.rand_like(param)
            optimizer.step()
            optim_state_dict_cpu = deepcopy(optimizer.state_dict())
            optim_state_dict_cpu["param_groups"][0][impl] = True

            # load
            optim_input.kwargs[impl] = True
            param_device = param.clone().detach().to(device=device)
            optimizer_device = optim_cls([param_device], **optim_input.kwargs)
            optimizer_device.load_state_dict(optim_state_dict_cpu)
            optimizer_device.zero_grad()
            param_device.grad = torch.rand_like(param_device)
            optimizer_device.step()

    @optims(optim_db, dtypes=[torch.float32])
    def test_param_groups_weight_decay(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        for optim_input in all_optim_inputs:
            weight_kwargs = optim_input.kwargs
            bias_kwargs = deepcopy(optim_input.kwargs)
            bias_kwargs["weight_decay"] = 0.0

            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
            input = torch.randn(5, device=device, dtype=dtype)

            optimizer = optim_cls(
                [
                    dict(params=[weight], **weight_kwargs),
                    dict(params=[bias], **bias_kwargs),
                ]
            )

            loss = (weight.mv(input) + bias).pow(2).sum()
            initial_value = loss.item()
            for _ in range(20):
                optimizer.zero_grad()
                loss = (weight.mv(input) + bias).pow(2).sum()
                loss.backward()
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
                    # NOT represent the expected use case for optims like SparseAdam!
                    weight.grad = weight.grad.to_sparse()
                    bias.grad = bias.grad.to_sparse()
                optimizer.step()

            # Test that the direction of loss moved appropriately
            if optim_input.kwargs.get("maximize", False):
                self.assertGreater(loss.item(), initial_value)
            else:
                self.assertLess(loss.item(), initial_value)

    @optims(optim_db, dtypes=[torch.float32])
    def test_param_groups_lr(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        for optim_input in all_optim_inputs:
            # optim_input.kwargs will be the param group kwargs, which should have >0 lr
            if "lr" not in optim_input.kwargs or optim_input.kwargs["lr"] == 0:
                optim_input.kwargs["lr"] = 1e-3
            outer_kwargs = {"lr": 1e-28}
            if optim_cls.__name__ == "Rprop":
                # Allow min step size to be 0
                outer_kwargs["step_sizes"] = (0, 50)

            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
            irrelevant = Parameter(torch.randn(2, device=device, dtype=dtype))
            irrelevant_clone = irrelevant.clone()
            input = torch.randn(5, device=device, dtype=dtype)
            optimizer = optim_cls(
                [
                    dict(params=[weight, bias], **optim_input.kwargs),
                    dict(params=[irrelevant]),
                ],
                **outer_kwargs,
            )

            loss = (weight.mv(input) + bias).pow(2).sum()
            initial_value = loss.item()
            for _ in range(20):
                optimizer.zero_grad()
                loss = (weight.mv(input) + bias).pow(2).sum()
                loss.backward()
                irrelevant.grad = torch.rand_like(irrelevant)
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
                    # NOT represent the expected use case for optims like SparseAdam!
                    weight.grad = weight.grad.to_sparse()
                    bias.grad = bias.grad.to_sparse()
                    irrelevant.grad = irrelevant.grad.to_sparse()
                optimizer.step()

            # Test that the direction of loss moved appropriately
            if optim_input.kwargs.get("maximize", False):
                self.assertGreater(loss.item(), initial_value)
            else:
                self.assertLess(loss.item(), initial_value)

            # Test that irrelevant parameters were not updated since lr was almost 0
            self.assertEqual(irrelevant, irrelevant_clone)

    @optims(optim_db, dtypes=[torch.float32])
    def test_step_is_noop_when_params_have_no_grad(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        params = [
            torch.randn(2, 3, requires_grad=False, device=device, dtype=dtype)
            for _ in range(2)
        ]
        old_params = [p.clone().detach() for p in params]

        def closure():
            return torch.tensor([1], device=device, dtype=dtype)

        for optim_input in all_optim_inputs:
            optimizer = optim_cls(params, **optim_input.kwargs)
            optimizer.step(closure)

    @optims(optim_db, dtypes=[torch.float32])
    def test_step_is_noop_for_zero_grads(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        param = torch.randn((5, 1), device=device, dtype=dtype, requires_grad=True)
        old_param = param.clone().detach()

        def closure():
            return torch.tensor([1], device=device, dtype=dtype)

        for optim_input in all_optim_inputs:
            kwargs = optim_input.kwargs

            # params will decay even if grads are empty if weight_decay != 0,
            # and capturable doesn't work for CPU tensors
            if kwargs.get("weight_decay", 0) != 0:
                continue

            # AdamW params will be updated regardless of grads due to lr, so make lr smaller
            if optim_cls.__name__ == "AdamW":
                kwargs["lr"] = (
                    torch.tensor(1e-5)
                    if isinstance(kwargs.get("lr", 1e-5), torch.Tensor)
                    else 1e-5
                )

            if kwargs.get("differentiable", False):
                params = [param.clone()]
            else:
                params = [param]

            optimizer = optim_cls(params, **kwargs)
            if optim_info.only_supports_sparse_grads:
                # Intentionally construct a multidimensional empty v for the sparse grad
                # Single dim v passes the test while multidim correctly repros the issue
                # https://github.com/pytorch/pytorch/issues/82486
                i = torch.empty((1, 0), device=device, dtype=dtype)
                v = torch.empty((0, 1), device=device, dtype=dtype)
                params[0].grad = torch.sparse_coo_tensor(
                    i, v, (5, 1), device=device, dtype=dtype
                )
            else:
                params[0].grad = torch.zeros_like(params[0])
            optimizer.step(closure)
            self.assertEqual(old_param, params[0])

    @optims(optim_db, dtypes=[torch.float32])
    def test_optimizer_can_be_printed(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        params = [
            Parameter(torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype))
            for _ in range(2)
        ]
        for optim_input in all_optim_inputs:
            optimizer = optim_cls(params, **optim_input.kwargs)
            optimizer.__repr__()

    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_deterministic(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        weight = Parameter(
            torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype)
        )
        bias = Parameter(torch.randn(2, requires_grad=True, device=device, dtype=dtype))
        input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
        params = [weight, bias]

        def fwd_bwd(optim, w, b, i):
            optim.zero_grad()
            loss = (w.mv(i) + b).pow(2).sum()
            loss.backward()
            if optim_info.only_supports_sparse_grads:
                if w.grad is not None:
                    w.grad = w.grad.to_sparse()
                if b.grad is not None:
                    b.grad = b.grad.to_sparse()
            return loss

        for optim_input in all_optim_inputs:
            optimizer = optim_cls(params, **optim_input.kwargs)
            closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)

            # Prime the optimizer
            for _ in range(10):
                if optim_info.step_requires_closure:
                    optimizer.step(closure)
                else:
                    closure()
                    optimizer.step()

            # Clone the weights and construct a new optimizer for them
            with torch.no_grad():
                weight_c = Parameter(weight.clone())
                bias_c = Parameter(bias.clone())

            optimizer_c = optim_cls([weight_c, bias_c], **optim_input.kwargs)
            closure_c = functools.partial(fwd_bwd, optimizer_c, weight_c, bias_c, input)

            # Load the state dict from the original optimizer into the new one
            optimizer_c.load_state_dict(deepcopy(optimizer.state_dict()))

            # Run both optimizers in parallel
            for _ in range(10):
                if optim_info.step_requires_closure:
                    optimizer.step(closure)
                    optimizer_c.step(closure_c)
                else:
                    closure()
                    closure_c()
                    optimizer.step()
                    optimizer_c.step()

                self.assertEqual(weight, weight_c)
                self.assertEqual(bias, bias_c)

            # Make sure state dict is deterministic with equal (not identical) parameters
            self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())

            # Make sure repeated parameters have identical representation (see #36831)
            optimizer_c.param_groups.extend(optimizer_c.param_groups)
            self.assertEqual(
                optimizer.state_dict()["param_groups"][-1],
                optimizer_c.state_dict()["param_groups"][-1],
            )

    @optims(optim_db, dtypes=[torch.float32])
    def test_can_load_older_state_dict(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        for optim_input in all_optim_inputs:
            torch.manual_seed(1)
            model = torch.nn.Sequential(
                torch.nn.Conv2d(4, 2, 1, stride=2),
                torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1),
            )
            model.to(dtype=dtype, device=device)
            input = torch.rand(1, 4, 16, 16, device=device, dtype=dtype)
            optimizer = optim_cls(model.parameters(), **optim_input.kwargs)

            def fwd_bwd(optim, mod, i):
                optim.zero_grad()
                loss = mod(i).sum()
                loss.backward()
                return loss

            for _ in range(3):
                if optim_info.step_requires_closure:
                    optimizer.step(functools.partial(fwd_bwd, optimizer, model, input))
                else:
                    fwd_bwd(optimizer, model, input)
                    optimizer.step()

            # old_state_dict has all new flags del'd
            old_state_dict = deepcopy(optimizer.state_dict())
            old_state_dict_pg = old_state_dict["param_groups"]
            for group in old_state_dict_pg:
                for flag in optim_info.not_og_supported_flags:
                    if flag in group:
                        del group[flag]

            optimizer.load_state_dict(old_state_dict)

            # Make sure we can still step
            if optim_info.step_requires_closure:
                optimizer.step(functools.partial(fwd_bwd, optimizer, model, input))
            else:
                fwd_bwd(optimizer, model, input)
                optimizer.step()

    @optims(optim_db, dtypes=[torch.float32])
    def test_save_load_equality_with_weights_only(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        weight = Parameter(
            torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype)
        )
        bias = Parameter(torch.randn(2, requires_grad=True, device=device, dtype=dtype))
        input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
        params = [weight, bias]

        def fwd_bwd(optim, w, b, i):
            optim.zero_grad()
            loss = (w.mv(i) + b).pow(2).sum()
            loss.backward()
            if optim_info.only_supports_sparse_grads:
                weight.grad = weight.grad.to_sparse()
                bias.grad = bias.grad.to_sparse()
            return loss

        for optim_input in all_optim_inputs:
            optimizer = optim_cls(params, **optim_input.kwargs)
            closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)

            # Prime the optimizer
            for _ in range(3):
                optimizer.step(closure)

            sd = optimizer.state_dict()

            # === Check saved/loaded state_dict are the same (including weights_only load). ===
            with tempfile.TemporaryFile() as f:
                torch.save(sd, f)
                f.seek(0)
                sd_copy = torch.load(f)
                self.assertEqual(sd_copy, sd)
                del sd_copy
                f.seek(0)
                sd_copy_wo = torch.load(f, weights_only=True)
                self.assertEqual(sd_copy_wo, sd)

    @optims(optim_db, dtypes=[torch.float32])
    def test_load_nontensor_step(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        params = [
            Parameter(torch.randn(2, 3, device=device, dtype=dtype)) for _ in range(2)
        ]
        for p in params:
            p.grad = torch.rand_like(p)
            if optim_info.only_supports_sparse_grads:
                # For this test, we naively convert the Tensor layout, which we know does
                # NOT represent the expected use case for optims like SparseAdam!
                p.grad = p.grad.to_sparse()

        # Needed for second order optims like LBFGS
        closure_loss = torch.rand(1, device=device, dtype=dtype)

        def closure():
            return closure_loss if optim_info.step_requires_closure else None

        for optim_input in all_optim_inputs:
            kwargs = optim_input.kwargs
            optimizer = optim_cls(params, **optim_input.kwargs)
            for _ in range(3):
                optimizer.step(closure)
            state_dict = deepcopy(optimizer.state_dict())
            for p_state in state_dict["state"].values():
                if "step" in p_state and torch.is_tensor(p_state["step"]):
                    p_state["step"] = p_state["step"].item()
            optimizer.load_state_dict(state_dict)
            optimizer.step(closure)

    @onlyCUDA
    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_with_cuda_params(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        # We limit our configs to CPU only, because we will be moving them to CUDA later
        cpu_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            "cpu", dtype, optim_info, skip=("differentiable",)
        )

        # Needed for second order optims like LBFGS
        closure_loss = torch.rand(1, device=device, dtype=dtype)

        def closure():
            return closure_loss if optim_info.step_requires_closure else None

        for optim_input in cpu_optim_inputs:
            if (
                "fused" in optim_input.kwargs
                and "cuda" not in optim_info.supports_fused_on
            ):
                self.skipTest(
                    f"cuda is not supported for fused on {optim_cls.__name__}"
                )
            params = [
                Parameter(torch.randn(2, 3, device="cpu", dtype=dtype))
                for _ in range(2)
            ]
            for p in params:
                p.grad = torch.randn_like(p)
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
                    # NOT represent the expected use case for optims like SparseAdam!
                    p.grad = p.grad.to_sparse()

            optimizer = optim_cls(params, **optim_input.kwargs)

            for _ in range(3):
                optimizer.step(closure)

            with torch.no_grad():
                params_cuda = [p.to(device="cuda") for p in params]
                for i, p in enumerate(params_cuda):
                    p.grad = params[i].grad.to(device="cuda")
            optimizer_cuda = optim_cls(params_cuda, **optim_input.kwargs)

            state_dict_cpu = deepcopy(optimizer.state_dict())
            state_dict_cuda = deepcopy(optimizer.state_dict())
            optimizer_cuda.load_state_dict(state_dict_cpu)

            # Make sure state_dict_cuda isn't modified by merely calling load_state_dict
            self.assertEqual(state_dict_cpu, state_dict_cuda)

            # Make sure that the device of state['step'] is still CPU _unless_ torch.compile() added a capturable!
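            # (Hedged note, not in the original: fused/capturable implementations keep
            # 'step' on the params' device so the kernel or captured graph can update it
            # without a host sync, which is why load_state_dict moves it to CUDA here.)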
    @onlyCUDA
    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_with_cuda_params(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        # We limit our configs to CPU only, because we will be moving them to CUDA later
        cpu_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            "cpu", dtype, optim_info, skip=("differentiable",)
        )

        # Needed for second order optims like LBFGS
        closure_loss = torch.rand(1, device=device, dtype=dtype)

        def closure():
            return closure_loss if optim_info.step_requires_closure else None

        for optim_input in cpu_optim_inputs:
            if (
                "fused" in optim_input.kwargs
                and "cuda" not in optim_info.supports_fused_on
            ):
                self.skipTest(
                    f"cuda is not supported for fused on {optim_cls.__name__}"
                )
            params = [
                Parameter(torch.randn(2, 3, device="cpu", dtype=dtype))
                for _ in range(2)
            ]
            for p in params:
                p.grad = torch.randn_like(p)
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
                    # NOT represent the expected use case for optims like SparseAdam!
                    p.grad = p.grad.to_sparse()

            optimizer = optim_cls(params, **optim_input.kwargs)

            for _ in range(3):
                optimizer.step(closure)

            with torch.no_grad():
                params_cuda = [p.to(device="cuda") for p in params]
                for i, p in enumerate(params_cuda):
                    p.grad = params[i].grad.to(device="cuda")
            optimizer_cuda = optim_cls(params_cuda, **optim_input.kwargs)

            state_dict_cpu = deepcopy(optimizer.state_dict())
            state_dict_cuda = deepcopy(optimizer.state_dict())
            optimizer_cuda.load_state_dict(state_dict_cuda)

            # Make sure state_dict_cuda isn't modified by merely calling load_state_dict
            self.assertEqual(state_dict_cpu, state_dict_cuda)

            # Make sure the device of state["step"] is still CPU _unless_ capturable or fused is on
            capturable = state_dict_cpu["param_groups"][0].get("capturable", False)
            fused = state_dict_cpu["param_groups"][0].get("fused", False)
            new_state_dict = optimizer_cuda.state_dict()
            for state_cpu, state_cuda in zip(
                state_dict_cpu["state"].values(), new_state_dict["state"].values()
            ):
                if "step" in state_cpu and torch.is_tensor(state_cpu["step"]):
                    self.assertEqual(
                        state_cuda["step"].device.type,
                        "cuda" if capturable or fused else "cpu",
                    )

            for _ in range(5):
                optimizer.step(closure)
                optimizer_cuda.step(closure)
                self.assertEqual(params, params_cuda)
                self.assertEqual(optimizer.state_dict(), optimizer_cuda.state_dict())

    @staticmethod
    def _state_dict_pre_hook(optimizer: Optimizer) -> None:
        optimizer.state["test"] = 1

    @staticmethod
    def _state_dict_post_hook(
        optimizer: Optimizer, state_dict: Dict[str, Any]
    ) -> Dict[str, Any]:
        if "test" in state_dict["state"]:
            state_dict["state"].pop("test")
            state_dict["ran_state_dict_pre_hook"] = True
        else:
            state_dict["ran_state_dict_pre_hook"] = False
        return state_dict

    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_pre_hook(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)
            optim.register_state_dict_pre_hook(self.__class__._state_dict_pre_hook)
            state_dict = optim.state_dict()
            self.assertEqual(state_dict["state"]["test"], 1)

    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_post_hook(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)
            optim.register_state_dict_post_hook(self.__class__._state_dict_post_hook)
            state_dict = optim.state_dict()
            self.assertFalse(state_dict["ran_state_dict_pre_hook"])

    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_pre_post_hook(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)
            optim.register_state_dict_pre_hook(self.__class__._state_dict_pre_hook)
            optim.register_state_dict_post_hook(self.__class__._state_dict_post_hook)
            state_dict = optim.state_dict()
            self.assertFalse("test" in state_dict["state"])
            self.assertTrue(state_dict["ran_state_dict_pre_hook"])
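
    # Illustrative sketch (not invoked by the tests): how the state_dict pre/post hooks
    # compose for an end user. The pre-hook runs before state_dict() assembles the dict and
    # may mutate optimizer state; the post-hook runs afterwards and may edit or replace the
    # returned dict. The "version" metadata key is a hypothetical example.
    @staticmethod
    def _example_state_dict_hooks():
        param = Parameter(torch.randn(2, 3))
        opt = SGD([param], lr=0.1)

        def pre_hook(optimizer):
            optimizer.state["scratch"] = 1

        def post_hook(optimizer, state_dict):
            state_dict["state"].pop("scratch", None)
            state_dict["version"] = 1  # hypothetical metadata
            return state_dict

        opt.register_state_dict_pre_hook(pre_hook)
        opt.register_state_dict_post_hook(post_hook)
        return opt.state_dict()
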
    @staticmethod
    def _load_state_dict_pre_hook1(
        optimizer: Optimizer, state_dict: Dict[str, Any]
    ) -> None:
        state_dict["param_groups"][0]["lr"] = 0.002

    @staticmethod
    def _load_state_dict_pre_hook2(
        optimizer: Optimizer, state_dict: Dict[str, Any]
    ) -> Dict[str, Any]:
        # The typical use case for returning a state dict is to drastically modify it.
        # We simulate that by making a deep copy and verifying that my_state_dict,
        # not the original, is the dict that actually gets used.
        my_state_dict = deepcopy(state_dict)
        my_state_dict["param_groups"][0]["lr"] = 0.003
        return my_state_dict

    @staticmethod
    def _load_state_dict_post_hook(optimizer: Optimizer) -> None:
        optimizer.state["ran_load_state_dict_pre_hook2"] = (
            optimizer.param_groups[0]["lr"] == 0.003
        )
        optimizer.state["ran_load_state_dict_post_hook"] = True

    @optims(optim_db, dtypes=[torch.float32])
    def test_load_state_dict_pre_hook_and_prepend(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)
            state_dict = optim.state_dict()

            # Usually one would load into a fresh optimizer instance, but reusing the
            # same one is equivalent for this test.
            optim.register_load_state_dict_pre_hook(
                self.__class__._load_state_dict_pre_hook1
            )
            optim.load_state_dict(state_dict)
            self.assertEqual(optim.param_groups[0]["lr"], 0.002)

            optim.register_load_state_dict_pre_hook(
                self.__class__._load_state_dict_pre_hook2, prepend=True
            )
            optim.load_state_dict(state_dict)
            # If prepend were False, the lr would end up as 0.003. Since prepend is True,
            # hook2 runs first and hook1 then overrides the lr back to 0.002.
            self.assertEqual(optim.param_groups[0]["lr"], 0.002)

    @optims(optim_db, dtypes=[torch.float32])
    def test_load_state_dict_post_hook(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)

            optim.register_load_state_dict_post_hook(
                self.__class__._load_state_dict_post_hook
            )
            optim.load_state_dict(optim.state_dict())
            self.assertFalse(optim.state["ran_load_state_dict_pre_hook2"])
            self.assertTrue(optim.state["ran_load_state_dict_post_hook"])

    @optims(optim_db, dtypes=[torch.float32])
    def test_load_state_dict_pre_post_hook(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)

            optim.register_load_state_dict_pre_hook(
                self.__class__._load_state_dict_pre_hook2
            )
            optim.register_load_state_dict_post_hook(
                self.__class__._load_state_dict_post_hook
            )
            optim.load_state_dict(optim.state_dict())
            self.assertTrue(optim.state["ran_load_state_dict_pre_hook2"])
            self.assertTrue(optim.state["ran_load_state_dict_post_hook"])
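
    # Illustrative sketch (not invoked by the tests): load_state_dict pre-hook ordering
    # with prepend=True. A prepended hook runs first, so a hook registered earlier without
    # prepend still runs last and can override what the prepended hook returned. The lr
    # values are arbitrary.
    @staticmethod
    def _example_load_state_dict_prepend_order():
        param = Parameter(torch.randn(2, 3))
        opt = SGD([param], lr=0.1)
        sd = opt.state_dict()

        def set_lr_in_place(optimizer, state_dict):
            state_dict["param_groups"][0]["lr"] = 0.002

        def return_new_dict(optimizer, state_dict):
            new_sd = deepcopy(state_dict)
            new_sd["param_groups"][0]["lr"] = 0.003
            return new_sd

        opt.register_load_state_dict_pre_hook(set_lr_in_place)
        opt.register_load_state_dict_pre_hook(return_new_dict, prepend=True)
        opt.load_state_dict(sd)
        # return_new_dict ran first (prepended), then set_lr_in_place overrode the lr,
        # so opt.param_groups[0]["lr"] is 0.002 here.
        return opt.param_groups[0]["lr"]
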
    @optims(optim_db, dtypes=[torch.float32])
    def test_step_post_hook(self, device, dtype, optim_info):
        def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data += 2

        params = [torch.tensor([1, 1], device=device, dtype=dtype)]

        def dummy_closure():
            return 1

        closure = dummy_closure if optim_info.step_requires_closure else None

        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            optim = optim_info.optim_cls(params, **optim_input.kwargs)
            data = 2
            hook_handle = optim.register_step_post_hook(post_hook)

            optim.step(closure)
            optim.step(closure)
            # check that the post hook fired on both steps
            self.assertEqual(data, 6)

            # remove the handle, take another step, and verify the hook no longer fires
            hook_handle.remove()

            optim.step(closure)
            self.assertEqual(data, 6)

    @optims(optim_db, dtypes=[torch.float32])
    def test_step_pre_hook(self, device, dtype, optim_info):
        def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data += 2

        params = [torch.tensor([1, 1], device=device, dtype=dtype)]

        def dummy_closure():
            return 1

        closure = dummy_closure if optim_info.step_requires_closure else None

        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            optim = optim_info.optim_cls(params, **optim_input.kwargs)
            data = 5
            hook_handle = optim.register_step_pre_hook(pre_hook)

            optim.step(closure)
            optim.step(closure)
            # check that the pre hook fired on both steps
            self.assertEqual(data, 9)

            # remove the handle, take another step, and verify the hook no longer fires
            hook_handle.remove()

            optim.step(closure)
            self.assertEqual(data, 9)
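
    # Illustrative sketch (not invoked by the tests): counting steps with per-optimizer
    # step hooks and detaching them through the returned handles. The counter dict is an
    # arbitrary choice for this example.
    @staticmethod
    def _example_step_hooks_with_handles():
        counts = {"pre": 0, "post": 0}
        param = Parameter(torch.randn(2, 3))
        opt = SGD([param], lr=0.1)

        def pre_hook(optimizer, args, kwargs):
            counts["pre"] += 1

        def post_hook(optimizer, args, kwargs):
            counts["post"] += 1

        pre_handle = opt.register_step_pre_hook(pre_hook)
        post_handle = opt.register_step_post_hook(post_hook)

        param.sum().backward()
        opt.step()  # counts == {"pre": 1, "post": 1}

        # Removing the handles detaches the hooks; further steps leave the counts unchanged.
        pre_handle.remove()
        post_handle.remove()
        opt.step()
        return counts
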
    @optims(optim_db, dtypes=[torch.float32])
    def test_step_all_hooks(self, device, dtype, optim_info):
        def global_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data.append(0)

        def global_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data.append(5)

        def local_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data.append(1)

        def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data.append(2)

        params = [torch.tensor([1, 1], device=device, dtype=dtype)]

        def dummy_closure():
            return 1

        closure = dummy_closure if optim_info.step_requires_closure else None

        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            optim = optim_info.optim_cls(params, **optim_input.kwargs)
            optim2 = SGD(params)
            data = []

            # register global hooks, which fire for every optimizer
            global_pre_handle = register_optimizer_step_pre_hook(global_pre_hook)
            global_post_handle = register_optimizer_step_post_hook(global_post_hook)

            # register local hooks
            first_pre_handle = optim.register_step_pre_hook(local_pre_hook)
            first_post_handle = optim.register_step_post_hook(local_post_hook)
            second_pre_handle = optim2.register_step_pre_hook(local_pre_hook)
            second_post_handle = optim2.register_step_post_hook(local_post_hook)

            optim.step(closure)
            self.assertListEqual(data, [0, 1, 2, 5])
            optim2.step(closure)
            self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5])
            optim.step(closure)
            self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])

            # remove all hooks
            global_pre_handle.remove()
            global_post_handle.remove()
            first_pre_handle.remove()
            first_post_handle.remove()
            second_pre_handle.remove()
            second_post_handle.remove()

            optim.step(closure)
            optim2.step(closure)
            self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])

    @optims(optim_db, dtypes=[torch.float32])
    def test_deepcopy_copies_all_public_attrs(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )

        params = [
            Parameter(torch.randn(2, 3, device=device, dtype=dtype)) for _ in range(2)
        ]
        for p in params:
            p.grad = torch.rand_like(p)
            if optim_info.only_supports_sparse_grads:
                # For this test, we naively convert the Tensor layout, which we know does
                # NOT represent the expected use case for optims like SparseAdam!
                p.grad = p.grad.to_sparse()

        # Needed for second order optims like LBFGS
        def closure():
            return 1 if optim_info.step_requires_closure else None

        def getPublicAttrs(obj):
            return {k for k in obj.__dict__ if not k.startswith("_")}

        for optim_input in all_optim_inputs:
            optimizer = optim_cls(params, **optim_input.kwargs)

            # Make some state
            for _ in range(3):
                if optim_info.step_requires_closure:
                    optimizer.step(closure)
                else:
                    closure()
                    optimizer.step()

            self.assertEqual(
                getPublicAttrs(optimizer), getPublicAttrs(deepcopy(optimizer))
            )

    @optims(
        [optim for optim in optim_db if optim.step_requires_closure],
        dtypes=[torch.float32],
    )
    def test_second_order_optims_return_consistent_types(
        self, device, dtype, optim_info
    ):
        # Motivated by #7586
        optim_cls = optim_info.optim_cls
        params = [
            torch.randn(10, 5, device=device, dtype=dtype),
            torch.randn(10, device=device, dtype=dtype),
        ]

        def closure():
            return torch.tensor([10], device=device, dtype=dtype)

        for optim_input in optim_info.optim_inputs_func(device=device):
            # Currently, the only second order optim is LBFGS, so we just go ahead and modify
            # "tolerance_grad", but this may not scale if we add second order optims in the future
            kwargs = optim_input.kwargs
            kwargs["tolerance_grad"] = math.inf
            optim_inf = optim_cls(params, **kwargs)
            kwargs["tolerance_grad"] = -math.inf
            optim_neg_inf = optim_cls(params, **kwargs)

            res1 = optim_inf.step(closure)
            res2 = optim_neg_inf.step(closure)
            self.assertEqual(type(res1), type(res2))
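
    # Illustrative sketch (not invoked by the tests): global step hooks registered with
    # register_optimizer_step_pre_hook / register_optimizer_step_post_hook wrap the
    # per-instance hooks, so a single step fires them in the order
    # global pre -> local pre -> step -> local post -> global post. The event labels are
    # arbitrary choices for this example.
    @staticmethod
    def _example_global_and_local_step_hooks():
        events = []
        param = Parameter(torch.randn(2, 3))
        opt = SGD([param], lr=0.1)

        g_pre = register_optimizer_step_pre_hook(lambda o, a, k: events.append("global_pre"))
        g_post = register_optimizer_step_post_hook(lambda o, a, k: events.append("global_post"))
        l_pre = opt.register_step_pre_hook(lambda o, a, k: events.append("local_pre"))
        l_post = opt.register_step_post_hook(lambda o, a, k: events.append("local_post"))

        param.sum().backward()
        opt.step()
        # events == ["global_pre", "local_pre", "local_post", "global_post"]

        # Remove the global hooks afterwards; they would otherwise fire for every optimizer
        # created later in the process.
        for handle in (g_pre, g_post, l_pre, l_post):
            handle.remove()
        return events
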
    @onlyCUDA
    @optims(
        [
            optim
            for optim in optim_db
            if "cpu" in optim.supports_fused_on and "cuda" in optim.supports_fused_on
        ],
        dtypes=floating_types_and(
            torch.bfloat16,
            torch.float16,
        ),
    )
    def test_fused_cpu_matches_cuda(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        optim_inputs = optim_info.optim_inputs_func(device="cpu")
        for optim_input in optim_inputs:
            inpts, models, optimizers = [], [], []
            for dev in ("cpu", "cuda"):
                kwargs = optim_input.kwargs
                kwargs["fused"] = True
                inpt = torch.tensor(
                    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=dev
                ).reshape(3, 2)

                torch.manual_seed(1)
                model = torch.nn.Sequential(
                    torch.nn.Linear(2, 3),
                    torch.nn.Sigmoid(),
                    torch.nn.Linear(3, 1),
                    torch.nn.Sigmoid(),
                )
                model.to(dtype=dtype, device=dev)

                # foreach/fused optimizers should be tested with a zero-size tensor
                # as their last param.
                # ref: https://github.com/pytorch/pytorch/issues/100701
                empty_param = torch.empty(
                    (), device=dev, dtype=dtype, requires_grad=True
                )
                empty_param.grad = torch.rand_like(empty_param)
                params = list(model.parameters()) + [empty_param]

                optimizer = optim_cls(params, **kwargs)
                inpts.append(inpt)
                models.append(model)
                optimizers.append(optimizer)
            self._compare_between(inpts, models, optimizers)

    @onlyCUDA
    @optims(
        [
            o
            for o in optim_db
            if ("foreach" in o.supported_impls and o.optim_cls.__name__ != "Adafactor")
        ],
        dtypes=[torch.float32],
    )
    def test_defaults_changed_to_foreach(self, device, dtype, optim_info):
        # Test that optimizers now default to the foreach implementation, except Adafactor,
        # which defaults to the single tensor impl for memory efficiency.
        optim_cls = optim_info.optim_cls
        model = torch.nn.Linear(5, 5)
        model.to(dtype=dtype, device=device)
        inpt = torch.rand(2, 5, dtype=dtype, device=device)

        import inspect

        module = inspect.getmodule(optim_cls)

        for optim_input in optim_info.optim_inputs_func(device=device):
            optim = optim_cls(model.parameters(), **optim_input.kwargs)
            optim.zero_grad()
            output = model(inpt)
            loss = output.sum()
            loss.backward()
            with patch.object(
                module, f"_multi_tensor_{optim_cls.__name__.lower()}"
            ) as mocked_foreach_impl:
                optim.step()
                self.assertTrue(mocked_foreach_impl.called)

    @optims(optim_db, dtypes=[torch.float32])
    def test_non_empty_state(self, device, dtype, optim_info):
        # There are internal tests that check that the state is not empty
        optim_cls = optim_info.optim_cls
        model = torch.nn.Linear(5, 5)
        model.to(dtype=dtype, device=device)
        inpt = torch.rand(2, 5, dtype=dtype, device=device)

        for optim_input in optim_info.optim_inputs_func(device=device):
            optim = optim_cls(model.parameters(), **optim_input.kwargs)
            optim.zero_grad()
            output = model(inpt)
            loss = output.sum()
            loss.backward()

            if optim_info.only_supports_sparse_grads:
                for param in model.parameters():
                    if param.grad is not None:
                        param.grad = param.grad.to_sparse()

            if optim_info.step_requires_closure:
                optim.step(lambda: 1.0)
            else:
                optim.step()

            for state in optim.state.values():
                self.assertGreater(len(state), 0)


instantiate_device_type_tests(TestOptimRenewed, globals(), allow_mps=True)


if __name__ == "__main__":
    run_tests()